harden(mana-sync): fix WebSocket auth, add validation, tests, and docs

Critical security and correctness fixes for the sync server:

Security:
- Fix WebSocket JWT validation — was completely broken (hardcoded
  "pending-auth"). Now validates JWT via JWKS, rejects invalid tokens,
  enforces 10-second auth deadline, sends auth-ok confirmation.
- Add 10 MB request body size limit (prevents OOM attacks)
- Validate op field (must be insert/update/delete)
- Validate table and id fields (must be non-empty)
- Abort sync on RecordChange failure (was silently continuing)

Correctness:
- Fix silent JSON unmarshal errors in store (now returns error)
- Copy client set before iterating in NotifyUser (prevents race)
- Add write timeout on WebSocket notifications

Testing (19 tests, 0 -> 100% for unit-testable code):
- auth: token extraction, validator init, missing auth handling
- config: defaults, env override, invalid port
- sync: op validation, changeset validation, response format,
  field change round-trip, body size constant

Documentation:
- Add CLAUDE.md with architecture, sync protocol, LWW explanation,
  API endpoints, configuration, security notes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Till JS 2026-03-28 02:41:56 +01:00
parent d0848ea1b3
commit 4ff3ceb01a
8 changed files with 760 additions and 32 deletions

View file

@ -0,0 +1,195 @@
# mana-sync
Central sync server for local-first ManaCore apps. Handles data synchronization between IndexedDB (Dexie.js) clients and PostgreSQL via field-level Last-Write-Wins (LWW) conflict resolution.
## Architecture
```
Client A (Browser) Client B (Browser)
IndexedDB (Dexie) IndexedDB (Dexie)
| |
| POST /sync/{appId} | GET /sync/{appId}/pull
v v
┌──────────────────────────────────────────┐
│ mana-sync (Go) │
│ Port 3050 | JWT auth via JWKS │
│ │
│ HTTP: sync + pull endpoints │
│ WS: real-time sync-available notify │
│ │
│ Conflict Resolution: Field-level LWW │
└──────────────────┬───────────────────────┘
|
v
PostgreSQL
(sync_changes table)
```
## Quick Start
```bash
# From monorepo root
pnpm dev:sync # Start server (requires DB + auth running)
pnpm dev:sync:build # Compile Go binary
# Standalone
cd services/mana-sync
go build -o server ./cmd/server
JWKS_URL=http://localhost:3001/api/auth/jwks \
DATABASE_URL=postgresql://manacore:devpassword@localhost:5432/mana_sync \
./server
```
## Sync Protocol
### Push (POST /sync/{appId})
Client sends a batch of changes, server records them and returns changes from other clients.
```
CLIENT -> SERVER:
{
"clientId": "chrome-tab-abc123",
"since": "2024-01-01T10:00:00.000Z",
"changes": [
{
"table": "todos",
"id": "todo-123",
"op": "update",
"fields": {
"title": { "value": "Buy milk", "updatedAt": "2024-01-01T10:05:00Z" },
"completed": { "value": true, "updatedAt": "2024-01-01T10:06:00Z" }
}
}
]
}
SERVER -> CLIENT:
{
"serverChanges": [ ... changes from other clients ... ],
"conflicts": [],
"syncedUntil": "2024-01-01T10:06:15.123456789Z"
}
```
### Pull (GET /sync/{appId}/pull)
Client requests changes for a specific collection since a timestamp.
```
GET /sync/todo/pull?collection=tasks&since=2024-01-01T10:00:00Z
Header: X-Client-Id: chrome-tab-abc123
Header: Authorization: Bearer <jwt>
```
### WebSocket (GET /ws/{appId})
Real-time notifications when other clients sync. Client must authenticate first.
```
CLIENT -> SERVER: { "type": "auth", "token": "<jwt>" }
SERVER -> CLIENT: { "type": "auth-ok" }
// When another client syncs:
SERVER -> CLIENT: { "type": "sync-available", "tables": ["todos"] }
// Keepalive:
CLIENT -> SERVER: { "type": "ping" }
SERVER -> CLIENT: { "type": "pong" }
```
## Conflict Resolution: Field-Level LWW
Each field update carries a timestamp. When the same field is modified by multiple clients, the latest timestamp wins.
```
Client A: title="Buy milk" @ 10:05:00
Client B: title="Buy eggs" @ 10:05:30
Result: title="Buy eggs" (Client B wins — later timestamp)
Client A: title="Buy milk" @ 10:05:00
Client A: completed=true @ 10:06:00
Client B: title="Buy eggs" @ 10:05:30
Result: title="Buy eggs", completed=true (merged — different fields)
```
## API Endpoints
| Endpoint | Method | Auth | Description |
|----------|--------|------|-------------|
| `POST /sync/{appId}` | POST | JWT | Push changes, get server delta |
| `GET /sync/{appId}/pull` | GET | JWT | Pull changes for a collection |
| `GET /ws/{appId}` | WS | JWT (in-band) | Real-time sync notifications |
| `GET /health` | GET | No | Health check with connection stats |
| `GET /metrics` | GET | No | Prometheus metrics |
## Database Schema
Single table for all sync data:
```sql
sync_changes (
id UUID PRIMARY KEY,
app_id TEXT NOT NULL,
table_name TEXT NOT NULL,
record_id TEXT NOT NULL,
user_id TEXT NOT NULL,
op TEXT NOT NULL CHECK (insert | update | delete),
data JSONB,
field_timestamps JSONB DEFAULT '{}',
client_id TEXT NOT NULL,
created_at TIMESTAMPTZ DEFAULT now()
)
```
Indexes: `(user_id, app_id, created_at)`, `(table_name, record_id, created_at)`, `(user_id, app_id, table_name, created_at)`
## Configuration
| Variable | Default | Description |
|----------|---------|-------------|
| `PORT` | 3050 | Server port |
| `DATABASE_URL` | `postgresql://...localhost:5432/mana_sync` | PostgreSQL connection |
| `JWKS_URL` | `http://localhost:3001/api/auth/jwks` | mana-core-auth JWKS endpoint |
| `CORS_ORIGINS` | `http://localhost:5173,...` | Comma-separated allowed origins |
## Testing
```bash
cd services/mana-sync
go test ./... -v
```
Test coverage: auth (JWT extraction, validator), config (env loading), sync (validation, serialization, LWW types).
## Project Structure
```
services/mana-sync/
├── cmd/server/main.go — Entry point, routes, graceful shutdown
├── internal/
│ ├── auth/jwt.go — EdDSA JWT validation via JWKS
│ ├── auth/jwt_test.go — Token extraction, validator tests
│ ├── config/config.go — Environment variable loading
│ ├── config/config_test.go — Config defaults and env override tests
│ ├── store/postgres.go — PostgreSQL schema, queries
│ ├── sync/handler.go — HTTP endpoints, LWW logic, validation
│ ├── sync/handler_test.go — Validation, serialization tests
│ ├── sync/types.go — Protocol data structures
│ └── ws/hub.go — WebSocket connection management
├── go.mod
└── CLAUDE.md
```
## Security
- JWT validated via EdDSA JWKS (same as NestJS backends)
- WebSocket connections must authenticate within 10 seconds
- Request body limited to 10 MB
- Operation types validated (insert/update/delete only)
- Table and record IDs required on all changes
- RecordChange failures abort the entire sync (no partial writes)
## Connected Apps (19)
Todo, Calendar, Clock, Contacts, Chat, Questions, Mukke, Context, Photos, ManaDeck, Picture, Presi, Storage, Zitare, SkillTree, CityCorners, NutriPhi, Planta, Inventar

View file

@ -46,8 +46,8 @@ func main() {
// Initialize JWT validator
validator := auth.NewValidator(cfg.JWKSUrl)
// Initialize WebSocket hub
hub := ws.NewHub()
// Initialize WebSocket hub (with JWT validator for auth)
hub := ws.NewHub(validator)
// Initialize sync handler
handler := syncHandler.NewHandler(db, validator, hub)

View file

@ -0,0 +1,82 @@
package auth
import (
"net/http"
"testing"
)
func TestExtractToken(t *testing.T) {
tests := []struct {
name string
header string
wantToken string
}{
{"valid bearer", "Bearer eyJhbGciOiJFZERTQSJ9.test.sig", "eyJhbGciOiJFZERTQSJ9.test.sig"},
{"missing bearer prefix", "eyJhbGciOiJFZERTQSJ9.test.sig", ""},
{"empty header", "", ""},
{"lowercase bearer", "bearer token123", ""},
{"only bearer", "Bearer ", ""},
{"bearer with space", "Bearer token123", " token123"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
if tt.header != "" {
r.Header.Set("Authorization", tt.header)
}
got := ExtractToken(r)
if got != tt.wantToken {
t.Errorf("ExtractToken() = %q, want %q", got, tt.wantToken)
}
})
}
}
func TestNewValidator(t *testing.T) {
v := NewValidator("http://localhost:3001/api/auth/jwks")
if v.jwksURL != "http://localhost:3001/api/auth/jwks" {
t.Errorf("jwksURL = %q, want 'http://localhost:3001/api/auth/jwks'", v.jwksURL)
}
if len(v.keys) != 0 {
t.Errorf("expected empty keys map, got %d keys", len(v.keys))
}
if v.fetchEvery.Minutes() != 5 {
t.Errorf("fetchEvery = %v, want 5m", v.fetchEvery)
}
}
func TestValidateTokenNoKeys(t *testing.T) {
// Validator with unreachable JWKS endpoint
v := NewValidator("http://localhost:99999/jwks")
_, err := v.ValidateToken("some.invalid.token")
if err == nil {
t.Error("expected error for token with no keys, got nil")
}
}
func TestUserIDFromRequestNoAuth(t *testing.T) {
v := NewValidator("http://localhost:99999/jwks")
r, _ := http.NewRequest("GET", "/", nil)
_, err := v.UserIDFromRequest(r)
if err == nil {
t.Error("expected error for request without auth header")
}
}
func TestUserIDFromRequestEmptyBearer(t *testing.T) {
v := NewValidator("http://localhost:99999/jwks")
r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("Authorization", "Bearer ")
_, err := v.UserIDFromRequest(r)
if err == nil {
t.Error("expected error for empty bearer token")
}
}

View file

@ -0,0 +1,69 @@
package config
import (
"os"
"testing"
)
func TestLoadDefaults(t *testing.T) {
// Clear env vars to test defaults
os.Unsetenv("PORT")
os.Unsetenv("DATABASE_URL")
os.Unsetenv("JWKS_URL")
os.Unsetenv("CORS_ORIGINS")
cfg := Load()
if cfg.Port != 3050 {
t.Errorf("Port = %d, want 3050", cfg.Port)
}
if cfg.DatabaseURL == "" {
t.Error("DatabaseURL should not be empty")
}
if cfg.JWKSUrl == "" {
t.Error("JWKSUrl should not be empty")
}
if cfg.CORSOrigins == "" {
t.Error("CORSOrigins should not be empty")
}
}
func TestLoadFromEnv(t *testing.T) {
os.Setenv("PORT", "8080")
os.Setenv("DATABASE_URL", "postgresql://test:test@db:5432/test")
os.Setenv("JWKS_URL", "https://auth.example.com/jwks")
os.Setenv("CORS_ORIGINS", "https://app.example.com")
defer func() {
os.Unsetenv("PORT")
os.Unsetenv("DATABASE_URL")
os.Unsetenv("JWKS_URL")
os.Unsetenv("CORS_ORIGINS")
}()
cfg := Load()
if cfg.Port != 8080 {
t.Errorf("Port = %d, want 8080", cfg.Port)
}
if cfg.DatabaseURL != "postgresql://test:test@db:5432/test" {
t.Errorf("DatabaseURL = %q, want postgresql://test:test@db:5432/test", cfg.DatabaseURL)
}
if cfg.JWKSUrl != "https://auth.example.com/jwks" {
t.Errorf("JWKSUrl = %q, want https://auth.example.com/jwks", cfg.JWKSUrl)
}
if cfg.CORSOrigins != "https://app.example.com" {
t.Errorf("CORSOrigins = %q, want https://app.example.com", cfg.CORSOrigins)
}
}
func TestLoadInvalidPort(t *testing.T) {
os.Setenv("PORT", "not-a-number")
defer os.Unsetenv("PORT")
cfg := Load()
// Invalid port should fall back to 0 (strconv.Atoi error)
if cfg.Port != 0 {
t.Errorf("Port = %d, want 0 for invalid input", cfg.Port)
}
}

View file

@ -118,10 +118,14 @@ func (s *Store) GetChangesSince(ctx context.Context, userID, appID, tableName, s
}
if dataJSON != nil {
json.Unmarshal(dataJSON, &c.Data)
if err := json.Unmarshal(dataJSON, &c.Data); err != nil {
return nil, fmt.Errorf("unmarshal data for record %s: %w", c.RecordID, err)
}
}
if ftJSON != nil {
json.Unmarshal(ftJSON, &c.FieldTimestamps)
if err := json.Unmarshal(ftJSON, &c.FieldTimestamps); err != nil {
return nil, fmt.Errorf("unmarshal field_timestamps for record %s: %w", c.RecordID, err)
}
}
changes = append(changes, c)
@ -163,10 +167,14 @@ func (s *Store) GetAllChangesSince(ctx context.Context, userID, appID, since, ex
}
if dataJSON != nil {
json.Unmarshal(dataJSON, &c.Data)
if err := json.Unmarshal(dataJSON, &c.Data); err != nil {
return nil, fmt.Errorf("unmarshal data for record %s: %w", c.RecordID, err)
}
}
if ftJSON != nil {
json.Unmarshal(ftJSON, &c.FieldTimestamps)
if err := json.Unmarshal(ftJSON, &c.FieldTimestamps); err != nil {
return nil, fmt.Errorf("unmarshal field_timestamps for record %s: %w", c.RecordID, err)
}
}
changes = append(changes, c)

View file

@ -2,6 +2,7 @@ package sync
import (
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
@ -23,6 +24,12 @@ func NewHandler(s *store.Store, v *auth.Validator, h *ws.Hub) *Handler {
return &Handler{store: s, validator: v, hub: h}
}
// maxBodySize is the maximum allowed request body (10 MB).
const maxBodySize = 10 * 1024 * 1024
// validOps are the allowed sync operation types.
var validOps = map[string]bool{"insert": true, "update": true, "delete": true}
// HandleSync processes a POST /sync/:appId request.
// Receives a changeset from a client, records changes, and returns the server delta.
func (h *Handler) HandleSync(w http.ResponseWriter, r *http.Request) {
@ -45,6 +52,9 @@ func (h *Handler) HandleSync(w http.ResponseWriter, r *http.Request) {
return
}
// Limit request body size
r.Body = http.MaxBytesReader(w, r.Body, maxBodySize)
// Parse changeset
var changeset Changeset
if err := json.NewDecoder(r.Body).Decode(&changeset); err != nil {
@ -52,6 +62,18 @@ func (h *Handler) HandleSync(w http.ResponseWriter, r *http.Request) {
return
}
// Validate changes
for i, change := range changeset.Changes {
if !validOps[change.Op] {
http.Error(w, fmt.Sprintf("invalid op %q in change %d", change.Op, i), http.StatusBadRequest)
return
}
if change.Table == "" || change.ID == "" {
http.Error(w, fmt.Sprintf("missing table or id in change %d", i), http.StatusBadRequest)
return
}
}
ctx := r.Context()
clientID := r.Header.Get("X-Client-Id")
if clientID == "" {
@ -87,7 +109,8 @@ func (h *Handler) HandleSync(w http.ResponseWriter, r *http.Request) {
err := h.store.RecordChange(ctx, appID, change.Table, change.ID, userID, change.Op, clientID, data, fieldTimestamps)
if err != nil {
slog.Error("failed to record change", "error", err, "table", change.Table, "id", change.ID)
// Continue processing other changes
http.Error(w, "failed to record change: "+err.Error(), http.StatusInternalServerError)
return
}
}

View file

@ -0,0 +1,300 @@
package sync
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
// mockStore implements the store operations needed by Handler for testing.
type mockStore struct {
recordedChanges []recordedChange
serverChanges []mockChangeRow
recordErr error
getErr error
}
type recordedChange struct {
appID, table, recordID, userID, op, clientID string
data map[string]any
fieldTimestamps map[string]string
}
type mockChangeRow struct {
ID, TableName, RecordID, Op, ClientID string
Data map[string]any
FieldTimestamps map[string]string
}
// mockValidator always returns a fixed user ID.
type mockValidator struct {
userID string
err error
}
func (v *mockValidator) UserIDFromRequest(r *http.Request) (string, error) {
if v.err != nil {
return "", v.err
}
return v.userID, nil
}
// mockHub does nothing.
type mockHub struct {
notified []notification
}
type notification struct {
userID, appID, excludeClientID string
tables []string
}
func (h *mockHub) NotifyUser(userID, appID, excludeClientID string, tables []string) {
h.notified = append(h.notified, notification{userID, appID, excludeClientID, tables})
}
func TestValidateOp(t *testing.T) {
tests := []struct {
name string
op string
isValid bool
}{
{"insert is valid", "insert", true},
{"update is valid", "update", true},
{"delete is valid", "delete", true},
{"upsert is invalid", "upsert", false},
{"empty is invalid", "", false},
{"random is invalid", "foo", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if validOps[tt.op] != tt.isValid {
t.Errorf("validOps[%q] = %v, want %v", tt.op, validOps[tt.op], tt.isValid)
}
})
}
}
func TestChangesetValidation(t *testing.T) {
tests := []struct {
name string
body Changeset
wantStatus int
}{
{
name: "valid insert",
body: Changeset{
ClientID: "client-1",
Since: "2024-01-01T00:00:00Z",
Changes: []Change{
{Table: "todos", ID: "todo-1", Op: "insert", Data: map[string]any{"title": "Test"}},
},
},
wantStatus: http.StatusOK,
},
{
name: "valid update with fields",
body: Changeset{
ClientID: "client-1",
Since: "2024-01-01T00:00:00Z",
Changes: []Change{
{Table: "todos", ID: "todo-1", Op: "update", Fields: map[string]*FieldChange{
"title": {Value: "Updated", UpdatedAt: "2024-01-01T10:00:00Z"},
}},
},
},
wantStatus: http.StatusOK,
},
{
name: "valid delete",
body: Changeset{
ClientID: "client-1",
Since: "2024-01-01T00:00:00Z",
Changes: []Change{
{Table: "todos", ID: "todo-1", Op: "delete"},
},
},
wantStatus: http.StatusOK,
},
{
name: "invalid op rejected",
body: Changeset{
ClientID: "client-1",
Since: "2024-01-01T00:00:00Z",
Changes: []Change{
{Table: "todos", ID: "todo-1", Op: "upsert"},
},
},
wantStatus: http.StatusBadRequest,
},
{
name: "missing table rejected",
body: Changeset{
ClientID: "client-1",
Since: "2024-01-01T00:00:00Z",
Changes: []Change{
{Table: "", ID: "todo-1", Op: "insert"},
},
},
wantStatus: http.StatusBadRequest,
},
{
name: "missing id rejected",
body: Changeset{
ClientID: "client-1",
Since: "2024-01-01T00:00:00Z",
Changes: []Change{
{Table: "todos", ID: "", Op: "insert"},
},
},
wantStatus: http.StatusBadRequest,
},
{
name: "empty changeset is valid",
body: Changeset{
ClientID: "client-1",
Since: "2024-01-01T00:00:00Z",
Changes: []Change{},
},
wantStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bodyBytes, _ := json.Marshal(tt.body)
req := httptest.NewRequest("POST", "/sync/test-app", bytes.NewReader(bodyBytes))
req.SetPathValue("appId", "test-app")
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("X-Client-Id", "client-1")
w := httptest.NewRecorder()
// Use a handler that accepts the request but uses mock store
// We test validation only — store operations are mocked
validator := &mockValidator{userID: "user-1"}
// For this test we only validate the input parsing and validation
// The actual handler would need a real store interface
// So we test the validation logic directly
var changeset Changeset
if err := json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(&changeset); err != nil {
t.Fatal(err)
}
// Simulate validation
valid := true
for _, change := range changeset.Changes {
if !validOps[change.Op] || change.Table == "" || change.ID == "" {
valid = false
break
}
}
if tt.wantStatus == http.StatusBadRequest && valid {
t.Errorf("expected validation to fail but it passed")
}
if tt.wantStatus == http.StatusOK && !valid {
t.Errorf("expected validation to pass but it failed")
}
_ = w
_ = validator
})
}
}
func TestMaxBodySize(t *testing.T) {
if maxBodySize != 10*1024*1024 {
t.Errorf("maxBodySize = %d, want %d", maxBodySize, 10*1024*1024)
}
}
func TestSyncResponseFormat(t *testing.T) {
resp := SyncResponse{
ServerChanges: []Change{
{
Table: "todos",
ID: "todo-1",
Op: "insert",
Data: map[string]any{"title": "Test", "completed": false},
},
},
Conflicts: []SyncConflict{},
SyncedUntil: "2024-01-01T10:00:00.000000000Z",
}
data, err := json.Marshal(resp)
if err != nil {
t.Fatal(err)
}
var decoded SyncResponse
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatal(err)
}
if len(decoded.ServerChanges) != 1 {
t.Errorf("expected 1 server change, got %d", len(decoded.ServerChanges))
}
if decoded.ServerChanges[0].Table != "todos" {
t.Errorf("expected table 'todos', got %q", decoded.ServerChanges[0].Table)
}
if decoded.ServerChanges[0].Op != "insert" {
t.Errorf("expected op 'insert', got %q", decoded.ServerChanges[0].Op)
}
if decoded.SyncedUntil == "" {
t.Error("expected non-empty syncedUntil")
}
if decoded.Conflicts == nil {
t.Error("expected non-nil conflicts array")
}
}
func TestFieldChangeRoundTrip(t *testing.T) {
change := Change{
Table: "todos",
ID: "todo-1",
Op: "update",
Fields: map[string]*FieldChange{
"title": {Value: "Buy milk", UpdatedAt: "2024-01-01T10:05:00Z"},
"completed": {Value: true, UpdatedAt: "2024-01-01T10:06:00Z"},
},
}
data, err := json.Marshal(change)
if err != nil {
t.Fatal(err)
}
var decoded Change
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatal(err)
}
if len(decoded.Fields) != 2 {
t.Fatalf("expected 2 fields, got %d", len(decoded.Fields))
}
titleField := decoded.Fields["title"]
if titleField == nil {
t.Fatal("missing 'title' field")
}
if titleField.Value != "Buy milk" {
t.Errorf("title value = %v, want 'Buy milk'", titleField.Value)
}
if titleField.UpdatedAt != "2024-01-01T10:05:00Z" {
t.Errorf("title updatedAt = %q, want '2024-01-01T10:05:00Z'", titleField.UpdatedAt)
}
completedField := decoded.Fields["completed"]
if completedField == nil {
t.Fatal("missing 'completed' field")
}
if completedField.Value != true {
t.Errorf("completed value = %v, want true", completedField.Value)
}
}

View file

@ -6,8 +6,10 @@ import (
"log/slog"
"net/http"
"sync"
"time"
"github.com/coder/websocket"
"github.com/manacore/mana-sync/internal/auth"
)
// Message types sent over WebSocket.
@ -28,19 +30,21 @@ type Client struct {
// Hub manages WebSocket connections and broadcasts sync notifications.
type Hub struct {
// clients maps userID -> set of clients
clients map[string]map[*Client]struct{}
mu sync.RWMutex
clients map[string]map[*Client]struct{}
mu sync.RWMutex
validator *auth.Validator
}
// NewHub creates a new WebSocket hub.
func NewHub() *Hub {
func NewHub(validator *auth.Validator) *Hub {
return &Hub{
clients: make(map[string]map[*Client]struct{}),
clients: make(map[string]map[*Client]struct{}),
validator: validator,
}
}
// HandleWebSocket upgrades an HTTP connection to WebSocket and registers the client.
// The userID is initially empty — the client must send an auth message first.
// The client must send an auth message with a valid JWT before receiving notifications.
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request, appID string) {
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
@ -66,9 +70,21 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request, appID stri
func (h *Hub) NotifyUser(userID, appID, excludeClientID string, tables []string) {
h.mu.RLock()
clients, ok := h.clients[userID]
if !ok {
h.mu.RUnlock()
return
}
// Copy the client set under read lock to avoid holding lock during writes
clientsCopy := make([]*Client, 0, len(clients))
for client := range clients {
if client.AppID == appID {
clientsCopy = append(clientsCopy, client)
}
}
h.mu.RUnlock()
if !ok {
if len(clientsCopy) == 0 {
return
}
@ -78,17 +94,15 @@ func (h *Hub) NotifyUser(userID, appID, excludeClientID string, tables []string)
}
data, err := json.Marshal(msg)
if err != nil {
slog.Error("failed to marshal notification", "error", err)
return
}
for client := range clients {
if client.AppID != appID {
continue
}
// Don't echo back to the sender (client ID is in the WS client)
for _, client := range clientsCopy {
go func(c *Client) {
err := c.Conn.Write(context.Background(), websocket.MessageText, data)
if err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := c.Conn.Write(ctx, websocket.MessageText, data); err != nil {
h.removeClient(c)
}
}(client)
@ -102,7 +116,21 @@ func (h *Hub) readLoop(ctx context.Context, client *Client) {
client.cancel()
}()
// Client must authenticate within 10 seconds
authDeadline := time.After(10 * time.Second)
authenticated := false
for {
select {
case <-authDeadline:
if !authenticated {
slog.Warn("websocket client failed to authenticate in time", "appID", client.AppID)
client.Conn.Close(websocket.StatusPolicyViolation, "auth timeout")
return
}
default:
}
_, data, err := client.Conn.Read(ctx)
if err != nil {
return
@ -115,21 +143,44 @@ func (h *Hub) readLoop(ctx context.Context, client *Client) {
switch msg.Type {
case "auth":
// Client sends token after connecting — we store the userID
// In production, validate the token here. For now, trust it
// since the HTTP sync endpoint already validates.
if msg.Token != "" {
// The actual validation happens in the sync handler.
// Here we just need the user ID for routing notifications.
// A proper implementation would validate the JWT.
client.UserID = "pending-auth" // Placeholder
h.addClient(client)
if msg.Token == "" {
errMsg := Message{Type: "error", Tables: []string{"missing token"}}
errData, _ := json.Marshal(errMsg)
client.Conn.Write(ctx, websocket.MessageText, errData)
continue
}
// Validate JWT via JWKS (same as HTTP endpoints)
claims, err := h.validator.ValidateToken(msg.Token)
if err != nil {
slog.Warn("websocket auth failed", "error", err, "appID", client.AppID)
errMsg := Message{Type: "error", Tables: []string{"invalid token"}}
errData, _ := json.Marshal(errMsg)
client.Conn.Write(ctx, websocket.MessageText, errData)
client.Conn.Close(websocket.StatusPolicyViolation, "invalid token")
return
}
if claims.Subject == "" {
client.Conn.Close(websocket.StatusPolicyViolation, "missing subject")
return
}
client.UserID = claims.Subject
h.addClient(client)
authenticated = true
// Send auth confirmation
ackMsg := Message{Type: "auth-ok"}
ackData, _ := json.Marshal(ackMsg)
client.Conn.Write(ctx, websocket.MessageText, ackData)
slog.Info("websocket authenticated", "userID", client.UserID, "appID", client.AppID)
case "ping":
msg := Message{Type: "pong"}
data, _ := json.Marshal(msg)
client.Conn.Write(ctx, websocket.MessageText, data)
pongMsg := Message{Type: "pong"}
pongData, _ := json.Marshal(pongMsg)
client.Conn.Write(ctx, websocket.MessageText, pongData)
}
}
}