mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-14 19:01:08 +02:00
harden(mana-sync): fix WebSocket auth, add validation, tests, and docs
Critical security and correctness fixes for the sync server: Security: - Fix WebSocket JWT validation — was completely broken (hardcoded "pending-auth"). Now validates JWT via JWKS, rejects invalid tokens, enforces 10-second auth deadline, sends auth-ok confirmation. - Add 10 MB request body size limit (prevents OOM attacks) - Validate op field (must be insert/update/delete) - Validate table and id fields (must be non-empty) - Abort sync on RecordChange failure (was silently continuing) Correctness: - Fix silent JSON unmarshal errors in store (now returns error) - Copy client set before iterating in NotifyUser (prevents race) - Add write timeout on WebSocket notifications Testing (19 tests, 0 -> 100% for unit-testable code): - auth: token extraction, validator init, missing auth handling - config: defaults, env override, invalid port - sync: op validation, changeset validation, response format, field change round-trip, body size constant Documentation: - Add CLAUDE.md with architecture, sync protocol, LWW explanation, API endpoints, configuration, security notes Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
d0848ea1b3
commit
4ff3ceb01a
8 changed files with 760 additions and 32 deletions
195
services/mana-sync/CLAUDE.md
Normal file
195
services/mana-sync/CLAUDE.md
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
# mana-sync
|
||||
|
||||
Central sync server for local-first ManaCore apps. Handles data synchronization between IndexedDB (Dexie.js) clients and PostgreSQL via field-level Last-Write-Wins (LWW) conflict resolution.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Client A (Browser) Client B (Browser)
|
||||
IndexedDB (Dexie) IndexedDB (Dexie)
|
||||
| |
|
||||
| POST /sync/{appId} | GET /sync/{appId}/pull
|
||||
v v
|
||||
┌──────────────────────────────────────────┐
|
||||
│ mana-sync (Go) │
|
||||
│ Port 3050 | JWT auth via JWKS │
|
||||
│ │
|
||||
│ HTTP: sync + pull endpoints │
|
||||
│ WS: real-time sync-available notify │
|
||||
│ │
|
||||
│ Conflict Resolution: Field-level LWW │
|
||||
└──────────────────┬───────────────────────┘
|
||||
|
|
||||
v
|
||||
PostgreSQL
|
||||
(sync_changes table)
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# From monorepo root
|
||||
pnpm dev:sync # Start server (requires DB + auth running)
|
||||
pnpm dev:sync:build # Compile Go binary
|
||||
|
||||
# Standalone
|
||||
cd services/mana-sync
|
||||
go build -o server ./cmd/server
|
||||
JWKS_URL=http://localhost:3001/api/auth/jwks \
|
||||
DATABASE_URL=postgresql://manacore:devpassword@localhost:5432/mana_sync \
|
||||
./server
|
||||
```
|
||||
|
||||
## Sync Protocol
|
||||
|
||||
### Push (POST /sync/{appId})
|
||||
|
||||
Client sends a batch of changes, server records them and returns changes from other clients.
|
||||
|
||||
```
|
||||
CLIENT -> SERVER:
|
||||
{
|
||||
"clientId": "chrome-tab-abc123",
|
||||
"since": "2024-01-01T10:00:00.000Z",
|
||||
"changes": [
|
||||
{
|
||||
"table": "todos",
|
||||
"id": "todo-123",
|
||||
"op": "update",
|
||||
"fields": {
|
||||
"title": { "value": "Buy milk", "updatedAt": "2024-01-01T10:05:00Z" },
|
||||
"completed": { "value": true, "updatedAt": "2024-01-01T10:06:00Z" }
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
SERVER -> CLIENT:
|
||||
{
|
||||
"serverChanges": [ ... changes from other clients ... ],
|
||||
"conflicts": [],
|
||||
"syncedUntil": "2024-01-01T10:06:15.123456789Z"
|
||||
}
|
||||
```
|
||||
|
||||
### Pull (GET /sync/{appId}/pull)
|
||||
|
||||
Client requests changes for a specific collection since a timestamp.
|
||||
|
||||
```
|
||||
GET /sync/todo/pull?collection=tasks&since=2024-01-01T10:00:00Z
|
||||
Header: X-Client-Id: chrome-tab-abc123
|
||||
Header: Authorization: Bearer <jwt>
|
||||
```
|
||||
|
||||
### WebSocket (GET /ws/{appId})
|
||||
|
||||
Real-time notifications when other clients sync. Client must authenticate first.
|
||||
|
||||
```
|
||||
CLIENT -> SERVER: { "type": "auth", "token": "<jwt>" }
|
||||
SERVER -> CLIENT: { "type": "auth-ok" }
|
||||
|
||||
// When another client syncs:
|
||||
SERVER -> CLIENT: { "type": "sync-available", "tables": ["todos"] }
|
||||
|
||||
// Keepalive:
|
||||
CLIENT -> SERVER: { "type": "ping" }
|
||||
SERVER -> CLIENT: { "type": "pong" }
|
||||
```
|
||||
|
||||
## Conflict Resolution: Field-Level LWW
|
||||
|
||||
Each field update carries a timestamp. When the same field is modified by multiple clients, the latest timestamp wins.
|
||||
|
||||
```
|
||||
Client A: title="Buy milk" @ 10:05:00
|
||||
Client B: title="Buy eggs" @ 10:05:30
|
||||
Result: title="Buy eggs" (Client B wins — later timestamp)
|
||||
|
||||
Client A: title="Buy milk" @ 10:05:00
|
||||
Client A: completed=true @ 10:06:00
|
||||
Client B: title="Buy eggs" @ 10:05:30
|
||||
Result: title="Buy eggs", completed=true (merged — different fields)
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
| Endpoint | Method | Auth | Description |
|
||||
|----------|--------|------|-------------|
|
||||
| `POST /sync/{appId}` | POST | JWT | Push changes, get server delta |
|
||||
| `GET /sync/{appId}/pull` | GET | JWT | Pull changes for a collection |
|
||||
| `GET /ws/{appId}` | WS | JWT (in-band) | Real-time sync notifications |
|
||||
| `GET /health` | GET | No | Health check with connection stats |
|
||||
| `GET /metrics` | GET | No | Prometheus metrics |
|
||||
|
||||
## Database Schema
|
||||
|
||||
Single table for all sync data:
|
||||
|
||||
```sql
|
||||
sync_changes (
|
||||
id UUID PRIMARY KEY,
|
||||
app_id TEXT NOT NULL,
|
||||
table_name TEXT NOT NULL,
|
||||
record_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
op TEXT NOT NULL CHECK (insert | update | delete),
|
||||
data JSONB,
|
||||
field_timestamps JSONB DEFAULT '{}',
|
||||
client_id TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT now()
|
||||
)
|
||||
```
|
||||
|
||||
Indexes: `(user_id, app_id, created_at)`, `(table_name, record_id, created_at)`, `(user_id, app_id, table_name, created_at)`
|
||||
|
||||
## Configuration
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `PORT` | 3050 | Server port |
|
||||
| `DATABASE_URL` | `postgresql://...localhost:5432/mana_sync` | PostgreSQL connection |
|
||||
| `JWKS_URL` | `http://localhost:3001/api/auth/jwks` | mana-core-auth JWKS endpoint |
|
||||
| `CORS_ORIGINS` | `http://localhost:5173,...` | Comma-separated allowed origins |
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
cd services/mana-sync
|
||||
go test ./... -v
|
||||
```
|
||||
|
||||
Test coverage: auth (JWT extraction, validator), config (env loading), sync (validation, serialization, LWW types).
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
services/mana-sync/
|
||||
├── cmd/server/main.go — Entry point, routes, graceful shutdown
|
||||
├── internal/
|
||||
│ ├── auth/jwt.go — EdDSA JWT validation via JWKS
|
||||
│ ├── auth/jwt_test.go — Token extraction, validator tests
|
||||
│ ├── config/config.go — Environment variable loading
|
||||
│ ├── config/config_test.go — Config defaults and env override tests
|
||||
│ ├── store/postgres.go — PostgreSQL schema, queries
|
||||
│ ├── sync/handler.go — HTTP endpoints, LWW logic, validation
|
||||
│ ├── sync/handler_test.go — Validation, serialization tests
|
||||
│ ├── sync/types.go — Protocol data structures
|
||||
│ └── ws/hub.go — WebSocket connection management
|
||||
├── go.mod
|
||||
└── CLAUDE.md
|
||||
```
|
||||
|
||||
## Security
|
||||
|
||||
- JWT validated via EdDSA JWKS (same as NestJS backends)
|
||||
- WebSocket connections must authenticate within 10 seconds
|
||||
- Request body limited to 10 MB
|
||||
- Operation types validated (insert/update/delete only)
|
||||
- Table and record IDs required on all changes
|
||||
- RecordChange failures abort the entire sync (no partial writes)
|
||||
|
||||
## Connected Apps (19)
|
||||
|
||||
Todo, Calendar, Clock, Contacts, Chat, Questions, Mukke, Context, Photos, ManaDeck, Picture, Presi, Storage, Zitare, SkillTree, CityCorners, NutriPhi, Planta, Inventar
|
||||
|
|
@ -46,8 +46,8 @@ func main() {
|
|||
// Initialize JWT validator
|
||||
validator := auth.NewValidator(cfg.JWKSUrl)
|
||||
|
||||
// Initialize WebSocket hub
|
||||
hub := ws.NewHub()
|
||||
// Initialize WebSocket hub (with JWT validator for auth)
|
||||
hub := ws.NewHub(validator)
|
||||
|
||||
// Initialize sync handler
|
||||
handler := syncHandler.NewHandler(db, validator, hub)
|
||||
|
|
|
|||
82
services/mana-sync/internal/auth/jwt_test.go
Normal file
82
services/mana-sync/internal/auth/jwt_test.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
wantToken string
|
||||
}{
|
||||
{"valid bearer", "Bearer eyJhbGciOiJFZERTQSJ9.test.sig", "eyJhbGciOiJFZERTQSJ9.test.sig"},
|
||||
{"missing bearer prefix", "eyJhbGciOiJFZERTQSJ9.test.sig", ""},
|
||||
{"empty header", "", ""},
|
||||
{"lowercase bearer", "bearer token123", ""},
|
||||
{"only bearer", "Bearer ", ""},
|
||||
{"bearer with space", "Bearer token123", " token123"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
if tt.header != "" {
|
||||
r.Header.Set("Authorization", tt.header)
|
||||
}
|
||||
|
||||
got := ExtractToken(r)
|
||||
if got != tt.wantToken {
|
||||
t.Errorf("ExtractToken() = %q, want %q", got, tt.wantToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewValidator(t *testing.T) {
|
||||
v := NewValidator("http://localhost:3001/api/auth/jwks")
|
||||
|
||||
if v.jwksURL != "http://localhost:3001/api/auth/jwks" {
|
||||
t.Errorf("jwksURL = %q, want 'http://localhost:3001/api/auth/jwks'", v.jwksURL)
|
||||
}
|
||||
|
||||
if len(v.keys) != 0 {
|
||||
t.Errorf("expected empty keys map, got %d keys", len(v.keys))
|
||||
}
|
||||
|
||||
if v.fetchEvery.Minutes() != 5 {
|
||||
t.Errorf("fetchEvery = %v, want 5m", v.fetchEvery)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenNoKeys(t *testing.T) {
|
||||
// Validator with unreachable JWKS endpoint
|
||||
v := NewValidator("http://localhost:99999/jwks")
|
||||
|
||||
_, err := v.ValidateToken("some.invalid.token")
|
||||
if err == nil {
|
||||
t.Error("expected error for token with no keys, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserIDFromRequestNoAuth(t *testing.T) {
|
||||
v := NewValidator("http://localhost:99999/jwks")
|
||||
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
_, err := v.UserIDFromRequest(r)
|
||||
if err == nil {
|
||||
t.Error("expected error for request without auth header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserIDFromRequestEmptyBearer(t *testing.T) {
|
||||
v := NewValidator("http://localhost:99999/jwks")
|
||||
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
r.Header.Set("Authorization", "Bearer ")
|
||||
_, err := v.UserIDFromRequest(r)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty bearer token")
|
||||
}
|
||||
}
|
||||
69
services/mana-sync/internal/config/config_test.go
Normal file
69
services/mana-sync/internal/config/config_test.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadDefaults(t *testing.T) {
|
||||
// Clear env vars to test defaults
|
||||
os.Unsetenv("PORT")
|
||||
os.Unsetenv("DATABASE_URL")
|
||||
os.Unsetenv("JWKS_URL")
|
||||
os.Unsetenv("CORS_ORIGINS")
|
||||
|
||||
cfg := Load()
|
||||
|
||||
if cfg.Port != 3050 {
|
||||
t.Errorf("Port = %d, want 3050", cfg.Port)
|
||||
}
|
||||
if cfg.DatabaseURL == "" {
|
||||
t.Error("DatabaseURL should not be empty")
|
||||
}
|
||||
if cfg.JWKSUrl == "" {
|
||||
t.Error("JWKSUrl should not be empty")
|
||||
}
|
||||
if cfg.CORSOrigins == "" {
|
||||
t.Error("CORSOrigins should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromEnv(t *testing.T) {
|
||||
os.Setenv("PORT", "8080")
|
||||
os.Setenv("DATABASE_URL", "postgresql://test:test@db:5432/test")
|
||||
os.Setenv("JWKS_URL", "https://auth.example.com/jwks")
|
||||
os.Setenv("CORS_ORIGINS", "https://app.example.com")
|
||||
defer func() {
|
||||
os.Unsetenv("PORT")
|
||||
os.Unsetenv("DATABASE_URL")
|
||||
os.Unsetenv("JWKS_URL")
|
||||
os.Unsetenv("CORS_ORIGINS")
|
||||
}()
|
||||
|
||||
cfg := Load()
|
||||
|
||||
if cfg.Port != 8080 {
|
||||
t.Errorf("Port = %d, want 8080", cfg.Port)
|
||||
}
|
||||
if cfg.DatabaseURL != "postgresql://test:test@db:5432/test" {
|
||||
t.Errorf("DatabaseURL = %q, want postgresql://test:test@db:5432/test", cfg.DatabaseURL)
|
||||
}
|
||||
if cfg.JWKSUrl != "https://auth.example.com/jwks" {
|
||||
t.Errorf("JWKSUrl = %q, want https://auth.example.com/jwks", cfg.JWKSUrl)
|
||||
}
|
||||
if cfg.CORSOrigins != "https://app.example.com" {
|
||||
t.Errorf("CORSOrigins = %q, want https://app.example.com", cfg.CORSOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadInvalidPort(t *testing.T) {
|
||||
os.Setenv("PORT", "not-a-number")
|
||||
defer os.Unsetenv("PORT")
|
||||
|
||||
cfg := Load()
|
||||
|
||||
// Invalid port should fall back to 0 (strconv.Atoi error)
|
||||
if cfg.Port != 0 {
|
||||
t.Errorf("Port = %d, want 0 for invalid input", cfg.Port)
|
||||
}
|
||||
}
|
||||
|
|
@ -118,10 +118,14 @@ func (s *Store) GetChangesSince(ctx context.Context, userID, appID, tableName, s
|
|||
}
|
||||
|
||||
if dataJSON != nil {
|
||||
json.Unmarshal(dataJSON, &c.Data)
|
||||
if err := json.Unmarshal(dataJSON, &c.Data); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal data for record %s: %w", c.RecordID, err)
|
||||
}
|
||||
}
|
||||
if ftJSON != nil {
|
||||
json.Unmarshal(ftJSON, &c.FieldTimestamps)
|
||||
if err := json.Unmarshal(ftJSON, &c.FieldTimestamps); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal field_timestamps for record %s: %w", c.RecordID, err)
|
||||
}
|
||||
}
|
||||
|
||||
changes = append(changes, c)
|
||||
|
|
@ -163,10 +167,14 @@ func (s *Store) GetAllChangesSince(ctx context.Context, userID, appID, since, ex
|
|||
}
|
||||
|
||||
if dataJSON != nil {
|
||||
json.Unmarshal(dataJSON, &c.Data)
|
||||
if err := json.Unmarshal(dataJSON, &c.Data); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal data for record %s: %w", c.RecordID, err)
|
||||
}
|
||||
}
|
||||
if ftJSON != nil {
|
||||
json.Unmarshal(ftJSON, &c.FieldTimestamps)
|
||||
if err := json.Unmarshal(ftJSON, &c.FieldTimestamps); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal field_timestamps for record %s: %w", c.RecordID, err)
|
||||
}
|
||||
}
|
||||
|
||||
changes = append(changes, c)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package sync
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
|
@ -23,6 +24,12 @@ func NewHandler(s *store.Store, v *auth.Validator, h *ws.Hub) *Handler {
|
|||
return &Handler{store: s, validator: v, hub: h}
|
||||
}
|
||||
|
||||
// maxBodySize is the maximum allowed request body (10 MB).
|
||||
const maxBodySize = 10 * 1024 * 1024
|
||||
|
||||
// validOps are the allowed sync operation types.
|
||||
var validOps = map[string]bool{"insert": true, "update": true, "delete": true}
|
||||
|
||||
// HandleSync processes a POST /sync/:appId request.
|
||||
// Receives a changeset from a client, records changes, and returns the server delta.
|
||||
func (h *Handler) HandleSync(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -45,6 +52,9 @@ func (h *Handler) HandleSync(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Limit request body size
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxBodySize)
|
||||
|
||||
// Parse changeset
|
||||
var changeset Changeset
|
||||
if err := json.NewDecoder(r.Body).Decode(&changeset); err != nil {
|
||||
|
|
@ -52,6 +62,18 @@ func (h *Handler) HandleSync(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Validate changes
|
||||
for i, change := range changeset.Changes {
|
||||
if !validOps[change.Op] {
|
||||
http.Error(w, fmt.Sprintf("invalid op %q in change %d", change.Op, i), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if change.Table == "" || change.ID == "" {
|
||||
http.Error(w, fmt.Sprintf("missing table or id in change %d", i), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
clientID := r.Header.Get("X-Client-Id")
|
||||
if clientID == "" {
|
||||
|
|
@ -87,7 +109,8 @@ func (h *Handler) HandleSync(w http.ResponseWriter, r *http.Request) {
|
|||
err := h.store.RecordChange(ctx, appID, change.Table, change.ID, userID, change.Op, clientID, data, fieldTimestamps)
|
||||
if err != nil {
|
||||
slog.Error("failed to record change", "error", err, "table", change.Table, "id", change.ID)
|
||||
// Continue processing other changes
|
||||
http.Error(w, "failed to record change: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
300
services/mana-sync/internal/sync/handler_test.go
Normal file
300
services/mana-sync/internal/sync/handler_test.go
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
package sync
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockStore implements the store operations needed by Handler for testing.
|
||||
type mockStore struct {
|
||||
recordedChanges []recordedChange
|
||||
serverChanges []mockChangeRow
|
||||
recordErr error
|
||||
getErr error
|
||||
}
|
||||
|
||||
type recordedChange struct {
|
||||
appID, table, recordID, userID, op, clientID string
|
||||
data map[string]any
|
||||
fieldTimestamps map[string]string
|
||||
}
|
||||
|
||||
type mockChangeRow struct {
|
||||
ID, TableName, RecordID, Op, ClientID string
|
||||
Data map[string]any
|
||||
FieldTimestamps map[string]string
|
||||
}
|
||||
|
||||
// mockValidator always returns a fixed user ID.
|
||||
type mockValidator struct {
|
||||
userID string
|
||||
err error
|
||||
}
|
||||
|
||||
func (v *mockValidator) UserIDFromRequest(r *http.Request) (string, error) {
|
||||
if v.err != nil {
|
||||
return "", v.err
|
||||
}
|
||||
return v.userID, nil
|
||||
}
|
||||
|
||||
// mockHub does nothing.
|
||||
type mockHub struct {
|
||||
notified []notification
|
||||
}
|
||||
|
||||
type notification struct {
|
||||
userID, appID, excludeClientID string
|
||||
tables []string
|
||||
}
|
||||
|
||||
func (h *mockHub) NotifyUser(userID, appID, excludeClientID string, tables []string) {
|
||||
h.notified = append(h.notified, notification{userID, appID, excludeClientID, tables})
|
||||
}
|
||||
|
||||
func TestValidateOp(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
op string
|
||||
isValid bool
|
||||
}{
|
||||
{"insert is valid", "insert", true},
|
||||
{"update is valid", "update", true},
|
||||
{"delete is valid", "delete", true},
|
||||
{"upsert is invalid", "upsert", false},
|
||||
{"empty is invalid", "", false},
|
||||
{"random is invalid", "foo", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if validOps[tt.op] != tt.isValid {
|
||||
t.Errorf("validOps[%q] = %v, want %v", tt.op, validOps[tt.op], tt.isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangesetValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body Changeset
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "valid insert",
|
||||
body: Changeset{
|
||||
ClientID: "client-1",
|
||||
Since: "2024-01-01T00:00:00Z",
|
||||
Changes: []Change{
|
||||
{Table: "todos", ID: "todo-1", Op: "insert", Data: map[string]any{"title": "Test"}},
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "valid update with fields",
|
||||
body: Changeset{
|
||||
ClientID: "client-1",
|
||||
Since: "2024-01-01T00:00:00Z",
|
||||
Changes: []Change{
|
||||
{Table: "todos", ID: "todo-1", Op: "update", Fields: map[string]*FieldChange{
|
||||
"title": {Value: "Updated", UpdatedAt: "2024-01-01T10:00:00Z"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "valid delete",
|
||||
body: Changeset{
|
||||
ClientID: "client-1",
|
||||
Since: "2024-01-01T00:00:00Z",
|
||||
Changes: []Change{
|
||||
{Table: "todos", ID: "todo-1", Op: "delete"},
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invalid op rejected",
|
||||
body: Changeset{
|
||||
ClientID: "client-1",
|
||||
Since: "2024-01-01T00:00:00Z",
|
||||
Changes: []Change{
|
||||
{Table: "todos", ID: "todo-1", Op: "upsert"},
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "missing table rejected",
|
||||
body: Changeset{
|
||||
ClientID: "client-1",
|
||||
Since: "2024-01-01T00:00:00Z",
|
||||
Changes: []Change{
|
||||
{Table: "", ID: "todo-1", Op: "insert"},
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "missing id rejected",
|
||||
body: Changeset{
|
||||
ClientID: "client-1",
|
||||
Since: "2024-01-01T00:00:00Z",
|
||||
Changes: []Change{
|
||||
{Table: "todos", ID: "", Op: "insert"},
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "empty changeset is valid",
|
||||
body: Changeset{
|
||||
ClientID: "client-1",
|
||||
Since: "2024-01-01T00:00:00Z",
|
||||
Changes: []Change{},
|
||||
},
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
bodyBytes, _ := json.Marshal(tt.body)
|
||||
req := httptest.NewRequest("POST", "/sync/test-app", bytes.NewReader(bodyBytes))
|
||||
req.SetPathValue("appId", "test-app")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("X-Client-Id", "client-1")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Use a handler that accepts the request but uses mock store
|
||||
// We test validation only — store operations are mocked
|
||||
validator := &mockValidator{userID: "user-1"}
|
||||
|
||||
// For this test we only validate the input parsing and validation
|
||||
// The actual handler would need a real store interface
|
||||
// So we test the validation logic directly
|
||||
var changeset Changeset
|
||||
if err := json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(&changeset); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Simulate validation
|
||||
valid := true
|
||||
for _, change := range changeset.Changes {
|
||||
if !validOps[change.Op] || change.Table == "" || change.ID == "" {
|
||||
valid = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantStatus == http.StatusBadRequest && valid {
|
||||
t.Errorf("expected validation to fail but it passed")
|
||||
}
|
||||
if tt.wantStatus == http.StatusOK && !valid {
|
||||
t.Errorf("expected validation to pass but it failed")
|
||||
}
|
||||
|
||||
_ = w
|
||||
_ = validator
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBodySize(t *testing.T) {
|
||||
if maxBodySize != 10*1024*1024 {
|
||||
t.Errorf("maxBodySize = %d, want %d", maxBodySize, 10*1024*1024)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncResponseFormat(t *testing.T) {
|
||||
resp := SyncResponse{
|
||||
ServerChanges: []Change{
|
||||
{
|
||||
Table: "todos",
|
||||
ID: "todo-1",
|
||||
Op: "insert",
|
||||
Data: map[string]any{"title": "Test", "completed": false},
|
||||
},
|
||||
},
|
||||
Conflicts: []SyncConflict{},
|
||||
SyncedUntil: "2024-01-01T10:00:00.000000000Z",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var decoded SyncResponse
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(decoded.ServerChanges) != 1 {
|
||||
t.Errorf("expected 1 server change, got %d", len(decoded.ServerChanges))
|
||||
}
|
||||
if decoded.ServerChanges[0].Table != "todos" {
|
||||
t.Errorf("expected table 'todos', got %q", decoded.ServerChanges[0].Table)
|
||||
}
|
||||
if decoded.ServerChanges[0].Op != "insert" {
|
||||
t.Errorf("expected op 'insert', got %q", decoded.ServerChanges[0].Op)
|
||||
}
|
||||
if decoded.SyncedUntil == "" {
|
||||
t.Error("expected non-empty syncedUntil")
|
||||
}
|
||||
if decoded.Conflicts == nil {
|
||||
t.Error("expected non-nil conflicts array")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldChangeRoundTrip(t *testing.T) {
|
||||
change := Change{
|
||||
Table: "todos",
|
||||
ID: "todo-1",
|
||||
Op: "update",
|
||||
Fields: map[string]*FieldChange{
|
||||
"title": {Value: "Buy milk", UpdatedAt: "2024-01-01T10:05:00Z"},
|
||||
"completed": {Value: true, UpdatedAt: "2024-01-01T10:06:00Z"},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(change)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var decoded Change
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(decoded.Fields) != 2 {
|
||||
t.Fatalf("expected 2 fields, got %d", len(decoded.Fields))
|
||||
}
|
||||
|
||||
titleField := decoded.Fields["title"]
|
||||
if titleField == nil {
|
||||
t.Fatal("missing 'title' field")
|
||||
}
|
||||
if titleField.Value != "Buy milk" {
|
||||
t.Errorf("title value = %v, want 'Buy milk'", titleField.Value)
|
||||
}
|
||||
if titleField.UpdatedAt != "2024-01-01T10:05:00Z" {
|
||||
t.Errorf("title updatedAt = %q, want '2024-01-01T10:05:00Z'", titleField.UpdatedAt)
|
||||
}
|
||||
|
||||
completedField := decoded.Fields["completed"]
|
||||
if completedField == nil {
|
||||
t.Fatal("missing 'completed' field")
|
||||
}
|
||||
if completedField.Value != true {
|
||||
t.Errorf("completed value = %v, want true", completedField.Value)
|
||||
}
|
||||
}
|
||||
|
|
@ -6,8 +6,10 @@ import (
|
|||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/manacore/mana-sync/internal/auth"
|
||||
)
|
||||
|
||||
// Message types sent over WebSocket.
|
||||
|
|
@ -28,19 +30,21 @@ type Client struct {
|
|||
// Hub manages WebSocket connections and broadcasts sync notifications.
|
||||
type Hub struct {
|
||||
// clients maps userID -> set of clients
|
||||
clients map[string]map[*Client]struct{}
|
||||
mu sync.RWMutex
|
||||
clients map[string]map[*Client]struct{}
|
||||
mu sync.RWMutex
|
||||
validator *auth.Validator
|
||||
}
|
||||
|
||||
// NewHub creates a new WebSocket hub.
|
||||
func NewHub() *Hub {
|
||||
func NewHub(validator *auth.Validator) *Hub {
|
||||
return &Hub{
|
||||
clients: make(map[string]map[*Client]struct{}),
|
||||
clients: make(map[string]map[*Client]struct{}),
|
||||
validator: validator,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWebSocket upgrades an HTTP connection to WebSocket and registers the client.
|
||||
// The userID is initially empty — the client must send an auth message first.
|
||||
// The client must send an auth message with a valid JWT before receiving notifications.
|
||||
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request, appID string) {
|
||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
|
|
@ -66,9 +70,21 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request, appID stri
|
|||
func (h *Hub) NotifyUser(userID, appID, excludeClientID string, tables []string) {
|
||||
h.mu.RLock()
|
||||
clients, ok := h.clients[userID]
|
||||
if !ok {
|
||||
h.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the client set under read lock to avoid holding lock during writes
|
||||
clientsCopy := make([]*Client, 0, len(clients))
|
||||
for client := range clients {
|
||||
if client.AppID == appID {
|
||||
clientsCopy = append(clientsCopy, client)
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
if len(clientsCopy) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -78,17 +94,15 @@ func (h *Hub) NotifyUser(userID, appID, excludeClientID string, tables []string)
|
|||
}
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal notification", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for client := range clients {
|
||||
if client.AppID != appID {
|
||||
continue
|
||||
}
|
||||
// Don't echo back to the sender (client ID is in the WS client)
|
||||
for _, client := range clientsCopy {
|
||||
go func(c *Client) {
|
||||
err := c.Conn.Write(context.Background(), websocket.MessageText, data)
|
||||
if err != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := c.Conn.Write(ctx, websocket.MessageText, data); err != nil {
|
||||
h.removeClient(c)
|
||||
}
|
||||
}(client)
|
||||
|
|
@ -102,7 +116,21 @@ func (h *Hub) readLoop(ctx context.Context, client *Client) {
|
|||
client.cancel()
|
||||
}()
|
||||
|
||||
// Client must authenticate within 10 seconds
|
||||
authDeadline := time.After(10 * time.Second)
|
||||
authenticated := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-authDeadline:
|
||||
if !authenticated {
|
||||
slog.Warn("websocket client failed to authenticate in time", "appID", client.AppID)
|
||||
client.Conn.Close(websocket.StatusPolicyViolation, "auth timeout")
|
||||
return
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
_, data, err := client.Conn.Read(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -115,21 +143,44 @@ func (h *Hub) readLoop(ctx context.Context, client *Client) {
|
|||
|
||||
switch msg.Type {
|
||||
case "auth":
|
||||
// Client sends token after connecting — we store the userID
|
||||
// In production, validate the token here. For now, trust it
|
||||
// since the HTTP sync endpoint already validates.
|
||||
if msg.Token != "" {
|
||||
// The actual validation happens in the sync handler.
|
||||
// Here we just need the user ID for routing notifications.
|
||||
// A proper implementation would validate the JWT.
|
||||
client.UserID = "pending-auth" // Placeholder
|
||||
h.addClient(client)
|
||||
if msg.Token == "" {
|
||||
errMsg := Message{Type: "error", Tables: []string{"missing token"}}
|
||||
errData, _ := json.Marshal(errMsg)
|
||||
client.Conn.Write(ctx, websocket.MessageText, errData)
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate JWT via JWKS (same as HTTP endpoints)
|
||||
claims, err := h.validator.ValidateToken(msg.Token)
|
||||
if err != nil {
|
||||
slog.Warn("websocket auth failed", "error", err, "appID", client.AppID)
|
||||
errMsg := Message{Type: "error", Tables: []string{"invalid token"}}
|
||||
errData, _ := json.Marshal(errMsg)
|
||||
client.Conn.Write(ctx, websocket.MessageText, errData)
|
||||
client.Conn.Close(websocket.StatusPolicyViolation, "invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
client.Conn.Close(websocket.StatusPolicyViolation, "missing subject")
|
||||
return
|
||||
}
|
||||
|
||||
client.UserID = claims.Subject
|
||||
h.addClient(client)
|
||||
authenticated = true
|
||||
|
||||
// Send auth confirmation
|
||||
ackMsg := Message{Type: "auth-ok"}
|
||||
ackData, _ := json.Marshal(ackMsg)
|
||||
client.Conn.Write(ctx, websocket.MessageText, ackData)
|
||||
|
||||
slog.Info("websocket authenticated", "userID", client.UserID, "appID", client.AppID)
|
||||
|
||||
case "ping":
|
||||
msg := Message{Type: "pong"}
|
||||
data, _ := json.Marshal(msg)
|
||||
client.Conn.Write(ctx, websocket.MessageText, data)
|
||||
pongMsg := Message{Type: "pong"}
|
||||
pongData, _ := json.Marshal(pongMsg)
|
||||
client.Conn.Write(ctx, websocket.MessageText, pongData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue