mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-14 18:41:08 +02:00
fix(mana-sync): enable row-level security on sync_changes
Defense-in-depth on top of the existing application-level WHERE clauses:
- Migrate() now ENABLE + FORCE row level security on sync_changes and
installs a policy that gates rows on current_setting('app.current_user_id').
FORCE makes the policy apply to the table owner too, so the application
role used by mana-sync cannot bypass it regardless of grants.
- New withUser(ctx, userID, fn) helper opens a transaction and calls
set_config('app.current_user_id', userID, true) before running fn.
Empty userIDs are rejected up-front so an unauthenticated request can
never reach the database with an empty RLS scope (which would match
every row).
- RecordChange / GetChangesSince / GetAllChangesSince all run inside
withUser. WITH CHECK on the policy double-validates the user_id column
on insert against the active session, so a future code path that
forgets the WHERE clause cannot leak data.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
28942abede
commit
a9529bcf1b
1 changed files with 124 additions and 75 deletions
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
|
|
@ -33,7 +34,12 @@ func (s *Store) Close() {
|
|||
s.pool.Close()
|
||||
}
|
||||
|
||||
// Migrate creates the sync_changes table if it doesn't exist.
|
||||
// Migrate creates the sync_changes table and enables row-level security.
|
||||
//
|
||||
// Defense-in-depth: every query also passes WHERE user_id = $1, but RLS makes
|
||||
// it impossible for a future query (or a query injection) to read or write
|
||||
// across user boundaries. The policy reads `app.current_user_id` from the
|
||||
// session config — store callers wrap their work in withUser() which sets it.
|
||||
func (s *Store) Migrate(ctx context.Context) error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS sync_changes (
|
||||
|
|
@ -57,13 +63,55 @@ func (s *Store) Migrate(ctx context.Context) error {
|
|||
|
||||
CREATE INDEX IF NOT EXISTS idx_sync_changes_since
|
||||
ON sync_changes (user_id, app_id, table_name, created_at);
|
||||
|
||||
ALTER TABLE sync_changes ENABLE ROW LEVEL SECURITY;
|
||||
-- FORCE makes RLS apply even to the table owner so that the application
|
||||
-- role used by mana-sync cannot bypass policies, regardless of grants.
|
||||
ALTER TABLE sync_changes FORCE ROW LEVEL SECURITY;
|
||||
|
||||
DROP POLICY IF EXISTS sync_changes_user_isolation ON sync_changes;
|
||||
CREATE POLICY sync_changes_user_isolation ON sync_changes
|
||||
USING (user_id = current_setting('app.current_user_id', true))
|
||||
WITH CHECK (user_id = current_setting('app.current_user_id', true));
|
||||
`
|
||||
|
||||
_, err := s.pool.Exec(ctx, query)
|
||||
return err
|
||||
}
|
||||
|
||||
// RecordChange stores a client change in the database.
|
||||
// withUser runs fn inside a transaction scoped to the given user_id.
|
||||
// All RLS-protected reads and writes performed via the supplied tx will be
|
||||
// confined to rows owned by userID. The session-local app.current_user_id
|
||||
// setting is reset automatically when the transaction ends.
|
||||
//
|
||||
// Empty userIDs are rejected up-front so an unauthenticated request can never
|
||||
// reach the database with an empty RLS scope (which would match every row).
|
||||
func (s *Store) withUser(ctx context.Context, userID string, fn func(pgx.Tx) error) error {
|
||||
if userID == "" {
|
||||
return fmt.Errorf("withUser: empty userID")
|
||||
}
|
||||
|
||||
tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback(ctx) }()
|
||||
|
||||
// set_config(name, value, is_local=true) is the parameterized form of
|
||||
// SET LOCAL — SET LOCAL itself does not accept bind parameters.
|
||||
if _, err := tx.Exec(ctx, "SELECT set_config('app.current_user_id', $1, true)", userID); err != nil {
|
||||
return fmt.Errorf("set rls user: %w", err)
|
||||
}
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// RecordChange stores a client change in the database. The insert is performed
|
||||
// inside an RLS-scoped transaction so the user_id column is double-checked
|
||||
// against the policy on the way in — a mismatched user_id would fail WITH CHECK.
|
||||
func (s *Store) RecordChange(ctx context.Context, appID, tableName, recordID, userID, op, clientID string, data map[string]any, fieldTimestamps map[string]string) error {
|
||||
dataJSON, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
|
|
@ -75,13 +123,14 @@ func (s *Store) RecordChange(ctx context.Context, appID, tableName, recordID, us
|
|||
return fmt.Errorf("marshal field_timestamps: %w", err)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO sync_changes (app_id, table_name, record_id, user_id, op, data, field_timestamps, client_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`
|
||||
|
||||
_, err = s.pool.Exec(ctx, query, appID, tableName, recordID, userID, op, dataJSON, ftJSON, clientID)
|
||||
return err
|
||||
return s.withUser(ctx, userID, func(tx pgx.Tx) error {
|
||||
query := `
|
||||
INSERT INTO sync_changes (app_id, table_name, record_id, user_id, op, data, field_timestamps, client_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`
|
||||
_, err := tx.Exec(ctx, query, appID, tableName, recordID, userID, op, dataJSON, ftJSON, clientID)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// GetChangesSince returns changes for a user+app+table since a given timestamp,
|
||||
|
|
@ -93,46 +142,46 @@ func (s *Store) GetChangesSince(ctx context.Context, userID, appID, tableName, s
|
|||
sinceTime = time.Unix(0, 0)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, table_name, record_id, op, data, field_timestamps, client_id, created_at
|
||||
FROM sync_changes
|
||||
WHERE user_id = $1 AND app_id = $2 AND table_name = $3
|
||||
AND created_at > $4 AND client_id != $5
|
||||
ORDER BY created_at ASC
|
||||
LIMIT $6
|
||||
`
|
||||
|
||||
rows, err := s.pool.Query(ctx, query, userID, appID, tableName, sinceTime, excludeClientID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var changes []ChangeRow
|
||||
for rows.Next() {
|
||||
var c ChangeRow
|
||||
var dataJSON, ftJSON []byte
|
||||
|
||||
err := rows.Scan(&c.ID, &c.TableName, &c.RecordID, &c.Op, &dataJSON, &ftJSON, &c.ClientID, &c.CreatedAt)
|
||||
err = s.withUser(ctx, userID, func(tx pgx.Tx) error {
|
||||
query := `
|
||||
SELECT id, table_name, record_id, op, data, field_timestamps, client_id, created_at
|
||||
FROM sync_changes
|
||||
WHERE user_id = $1 AND app_id = $2 AND table_name = $3
|
||||
AND created_at > $4 AND client_id != $5
|
||||
ORDER BY created_at ASC
|
||||
LIMIT $6
|
||||
`
|
||||
rows, err := tx.Query(ctx, query, userID, appID, tableName, sinceTime, excludeClientID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if dataJSON != nil {
|
||||
if err := json.Unmarshal(dataJSON, &c.Data); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal data for record %s: %w", c.RecordID, err)
|
||||
for rows.Next() {
|
||||
var c ChangeRow
|
||||
var dataJSON, ftJSON []byte
|
||||
|
||||
if err := rows.Scan(&c.ID, &c.TableName, &c.RecordID, &c.Op, &dataJSON, &ftJSON, &c.ClientID, &c.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if ftJSON != nil {
|
||||
if err := json.Unmarshal(ftJSON, &c.FieldTimestamps); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal field_timestamps for record %s: %w", c.RecordID, err)
|
||||
|
||||
if dataJSON != nil {
|
||||
if err := json.Unmarshal(dataJSON, &c.Data); err != nil {
|
||||
return fmt.Errorf("unmarshal data for record %s: %w", c.RecordID, err)
|
||||
}
|
||||
}
|
||||
if ftJSON != nil {
|
||||
if err := json.Unmarshal(ftJSON, &c.FieldTimestamps); err != nil {
|
||||
return fmt.Errorf("unmarshal field_timestamps for record %s: %w", c.RecordID, err)
|
||||
}
|
||||
}
|
||||
|
||||
changes = append(changes, c)
|
||||
}
|
||||
|
||||
changes = append(changes, c)
|
||||
}
|
||||
|
||||
return changes, rows.Err()
|
||||
return rows.Err()
|
||||
})
|
||||
return changes, err
|
||||
}
|
||||
|
||||
// GetAllChangesSince returns changes across all tables for a user+app.
|
||||
|
|
@ -142,46 +191,46 @@ func (s *Store) GetAllChangesSince(ctx context.Context, userID, appID, since, ex
|
|||
sinceTime = time.Unix(0, 0)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, table_name, record_id, op, data, field_timestamps, client_id, created_at
|
||||
FROM sync_changes
|
||||
WHERE user_id = $1 AND app_id = $2
|
||||
AND created_at > $3 AND client_id != $4
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 5000
|
||||
`
|
||||
|
||||
rows, err := s.pool.Query(ctx, query, userID, appID, sinceTime, excludeClientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var changes []ChangeRow
|
||||
for rows.Next() {
|
||||
var c ChangeRow
|
||||
var dataJSON, ftJSON []byte
|
||||
|
||||
err := rows.Scan(&c.ID, &c.TableName, &c.RecordID, &c.Op, &dataJSON, &ftJSON, &c.ClientID, &c.CreatedAt)
|
||||
err = s.withUser(ctx, userID, func(tx pgx.Tx) error {
|
||||
query := `
|
||||
SELECT id, table_name, record_id, op, data, field_timestamps, client_id, created_at
|
||||
FROM sync_changes
|
||||
WHERE user_id = $1 AND app_id = $2
|
||||
AND created_at > $3 AND client_id != $4
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 5000
|
||||
`
|
||||
rows, err := tx.Query(ctx, query, userID, appID, sinceTime, excludeClientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if dataJSON != nil {
|
||||
if err := json.Unmarshal(dataJSON, &c.Data); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal data for record %s: %w", c.RecordID, err)
|
||||
for rows.Next() {
|
||||
var c ChangeRow
|
||||
var dataJSON, ftJSON []byte
|
||||
|
||||
if err := rows.Scan(&c.ID, &c.TableName, &c.RecordID, &c.Op, &dataJSON, &ftJSON, &c.ClientID, &c.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if ftJSON != nil {
|
||||
if err := json.Unmarshal(ftJSON, &c.FieldTimestamps); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal field_timestamps for record %s: %w", c.RecordID, err)
|
||||
|
||||
if dataJSON != nil {
|
||||
if err := json.Unmarshal(dataJSON, &c.Data); err != nil {
|
||||
return fmt.Errorf("unmarshal data for record %s: %w", c.RecordID, err)
|
||||
}
|
||||
}
|
||||
if ftJSON != nil {
|
||||
if err := json.Unmarshal(ftJSON, &c.FieldTimestamps); err != nil {
|
||||
return fmt.Errorf("unmarshal field_timestamps for record %s: %w", c.RecordID, err)
|
||||
}
|
||||
}
|
||||
|
||||
changes = append(changes, c)
|
||||
}
|
||||
|
||||
changes = append(changes, c)
|
||||
}
|
||||
|
||||
return changes, rows.Err()
|
||||
return rows.Err()
|
||||
})
|
||||
return changes, err
|
||||
}
|
||||
|
||||
// ChangeRow is a row from the sync_changes table.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue