security: enforce per-user data isolation in memory service

- Make user_id required on all request models with field validators
- Always include user_id in WHERE clause for chunk queries (prevents cross-user data leak)
- Add bearer token auth on all endpoints except /health
- Add composite index on (user_id, room_id) for conversation_chunks
- Bot: guard query_chunks with sender check, pass room_id, send auth token
- Docker: pass MEMORY_SERVICE_TOKEN to both bot and memory-service

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Christian Gick
2026-03-08 13:45:15 +02:00
parent e584ce8ce0
commit 36c7e36456
3 changed files with 92 additions and 31 deletions

View File

@@ -1,5 +1,6 @@
import os
import logging
import secrets
import time
import hashlib
import base64
@@ -7,8 +8,8 @@ import base64
import asyncpg
import httpx
from cryptography.fernet import Fernet
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, field_validator
logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)
@@ -23,12 +24,23 @@ EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
app = FastAPI(title="Memory Service")
pool: asyncpg.Pool | None = None
owner_pool: asyncpg.Pool | None = None
async def verify_token(authorization: str | None = Header(None)):
"""Bearer token auth — skipped if MEMORY_SERVICE_TOKEN not configured (dev mode)."""
if not MEMORY_SERVICE_TOKEN:
return
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(401, "Missing or invalid Authorization header")
if not secrets.compare_digest(authorization[7:], MEMORY_SERVICE_TOKEN):
raise HTTPException(403, "Invalid token")
def _derive_user_key(user_id: str) -> bytes:
"""Derive a per-user Fernet key from master key + user_id via HMAC-SHA256."""
if not ENCRYPTION_KEY:
@@ -64,12 +76,26 @@ class StoreRequest(BaseModel):
fact: str
source_room: str = ""
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class QueryRequest(BaseModel):
user_id: str
query: str
top_k: int = 10
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class ChunkStoreRequest(BaseModel):
user_id: str
@@ -79,13 +105,27 @@ class ChunkStoreRequest(BaseModel):
source_event_id: str = ""
original_ts: float = 0.0
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class ChunkQueryRequest(BaseModel):
user_id: str = ""
user_id: str # REQUIRED — no default
room_id: str = ""
query: str
top_k: int = 5
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class ChunkBulkStoreRequest(BaseModel):
chunks: list[ChunkStoreRequest]
@@ -166,6 +206,9 @@ async def _init_db():
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
""")
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
""")
finally:
await owner_conn.close()
# Create restricted pool for all request handlers (RLS applies)
@@ -205,7 +248,7 @@ async def health():
@app.post("/memories/store")
async def store_memory(req: StoreRequest):
async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
"""Embed fact, deduplicate by cosine similarity, insert encrypted."""
if not req.fact.strip():
raise HTTPException(400, "Empty fact")
@@ -243,7 +286,7 @@ async def store_memory(req: StoreRequest):
@app.post("/memories/query")
async def query_memories(req: QueryRequest):
async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
"""Embed query, return top-K similar facts for user."""
embedding = await _embed(req.query)
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
@@ -276,7 +319,7 @@ async def query_memories(req: QueryRequest):
@app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str):
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
"""GDPR delete — remove all memories for a user."""
async with pool.acquire() as conn:
await _set_rls_user(conn, user_id)
@@ -287,7 +330,7 @@ async def delete_user_memories(user_id: str):
@app.get("/memories/{user_id}")
async def list_user_memories(user_id: str):
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
"""List all memories for a user (for UI/debug)."""
async with pool.acquire() as conn:
await _set_rls_user(conn, user_id)
@@ -314,7 +357,7 @@ async def list_user_memories(user_id: str):
@app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest):
async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
"""Store a conversation chunk with its summary embedding, encrypted."""
if not req.summary.strip():
raise HTTPException(400, "Empty summary")
@@ -342,34 +385,26 @@ async def store_chunk(req: ChunkStoreRequest):
@app.post("/chunks/query")
async def query_chunks(req: ChunkQueryRequest):
async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
"""Semantic search over conversation chunks. Filter by user_id and/or room_id."""
embedding = await _embed(req.query)
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
# Build WHERE clause dynamically based on provided filters
conditions = []
params: list = [vec_literal]
idx = 2
# Build WHERE clause — user_id is always required
conditions = [f"user_id = $2"]
params: list = [vec_literal, req.user_id]
idx = 3
if req.user_id:
conditions.append(f"user_id = ${idx}")
params.append(req.user_id)
idx += 1
if req.room_id:
conditions.append(f"room_id = ${idx}")
params.append(req.room_id)
idx += 1
where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
where = f"WHERE {' AND '.join(conditions)}"
params.append(req.top_k)
# Determine user_id for RLS and decryption
rls_user = req.user_id if req.user_id else ""
async with pool.acquire() as conn:
if rls_user:
await _set_rls_user(conn, rls_user)
await _set_rls_user(conn, req.user_id)
rows = await conn.fetch(
f"""
@@ -399,7 +434,7 @@ async def query_chunks(req: ChunkQueryRequest):
@app.post("/chunks/bulk-store")
async def bulk_store_chunks(req: ChunkBulkStoreRequest):
async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
"""Batch store conversation chunks. Embeds summaries in batches of 20."""
if not req.chunks:
return {"stored": 0}
@@ -440,7 +475,7 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest):
@app.get("/chunks/{user_id}/count")
async def count_user_chunks(user_id: str):
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
"""Count conversation chunks for a user."""
async with pool.acquire() as conn:
await _set_rls_user(conn, user_id)