security: enforce per-user data isolation in memory service
- Make user_id required on all request models with field validators
- Always include user_id in the WHERE clause for chunk queries (prevents cross-user data leaks)
- Add bearer-token auth on all endpoints except /health
- Add a composite index on (user_id, room_id) for conversation_chunks
- Bot: guard query_chunks with a sender check, pass room_id, send the auth token
- Docker: pass MEMORY_SERVICE_TOKEN to both bot and memory-service

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
import hashlib
|
||||
import base64
|
||||
@@ -7,8 +8,8 @@ import base64
|
||||
import asyncpg
|
||||
import httpx
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
logger = logging.getLogger("memory-service")
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -23,12 +24,23 @@ EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
|
||||
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
|
||||
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
|
||||
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
|
||||
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
|
||||
|
||||
app = FastAPI(title="Memory Service")
|
||||
pool: asyncpg.Pool | None = None
|
||||
owner_pool: asyncpg.Pool | None = None
|
||||
|
||||
|
||||
async def verify_token(authorization: str | None = Header(None)):
    """Validate the bearer token carried in the ``Authorization`` header.

    Used as a FastAPI dependency on the protected endpoints.

    Args:
        authorization: Raw ``Authorization`` header value, or ``None`` when
            the header is absent.

    Raises:
        HTTPException: 401 when the header is missing or does not start with
            ``"Bearer "``; 403 when the supplied token does not match
            ``MEMORY_SERVICE_TOKEN``.
    """
    # NOTE(review): when MEMORY_SERVICE_TOKEN is empty the check is skipped
    # entirely (dev mode), i.e. the service fails open. Confirm that every
    # production deployment sets the variable.
    if not MEMORY_SERVICE_TOKEN:
        return

    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(401, "Missing or invalid Authorization header")

    supplied_token = authorization.removeprefix("Bearer ")

    # Compare as bytes: secrets.compare_digest() raises TypeError for str
    # arguments containing non-ASCII characters, which would turn a malformed
    # header into a 500 instead of a clean 403.
    tokens_match = secrets.compare_digest(
        supplied_token.encode("utf-8"),
        MEMORY_SERVICE_TOKEN.encode("utf-8"),
    )
    if not tokens_match:
        raise HTTPException(403, "Invalid token")
|
||||
|
||||
|
||||
def _derive_user_key(user_id: str) -> bytes:
|
||||
"""Derive a per-user Fernet key from master key + user_id via HMAC-SHA256."""
|
||||
if not ENCRYPTION_KEY:
|
||||
@@ -64,12 +76,26 @@ class StoreRequest(BaseModel):
|
||||
fact: str
|
||||
source_room: str = ""
|
||||
|
||||
@field_validator('user_id')
|
||||
@classmethod
|
||||
def user_id_not_empty(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("user_id is required")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
    # Fields: the user whose data is searched, the query text, and the
    # maximum number of results to return.
    user_id: str
    query: str
    top_k: int = 10

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        """Reject empty or whitespace-only user IDs; return the stripped value."""
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("user_id is required")
        return cleaned
|
||||
|
||||
|
||||
class ChunkStoreRequest(BaseModel):
|
||||
user_id: str
|
||||
@@ -79,13 +105,27 @@ class ChunkStoreRequest(BaseModel):
|
||||
source_event_id: str = ""
|
||||
original_ts: float = 0.0
|
||||
|
||||
@field_validator('user_id')
|
||||
@classmethod
|
||||
def user_id_not_empty(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("user_id is required")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ChunkQueryRequest(BaseModel):
    # user_id is declared exactly once and without a default. A second
    # declaration with `= ""` would leave a class-level default behind, making
    # the field optional again and bypassing the validator below.
    user_id: str  # REQUIRED — no default
    room_id: str = ""  # optional extra filter; empty string means "not set"
    query: str
    top_k: int = 5

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        """Reject empty or whitespace-only user IDs; return the stripped value.

        Raises:
            ValueError: if ``v`` is empty or contains only whitespace.
        """
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()
|
||||
|
||||
|
||||
class ChunkBulkStoreRequest(BaseModel):
|
||||
chunks: list[ChunkStoreRequest]
|
||||
@@ -166,6 +206,9 @@ async def _init_db():
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
|
||||
""")
|
||||
finally:
|
||||
await owner_conn.close()
|
||||
# Create restricted pool for all request handlers (RLS applies)
|
||||
@@ -205,7 +248,7 @@ async def health():
|
||||
|
||||
|
||||
@app.post("/memories/store")
|
||||
async def store_memory(req: StoreRequest):
|
||||
async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
|
||||
"""Embed fact, deduplicate by cosine similarity, insert encrypted."""
|
||||
if not req.fact.strip():
|
||||
raise HTTPException(400, "Empty fact")
|
||||
@@ -243,7 +286,7 @@ async def store_memory(req: StoreRequest):
|
||||
|
||||
|
||||
@app.post("/memories/query")
|
||||
async def query_memories(req: QueryRequest):
|
||||
async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
|
||||
"""Embed query, return top-K similar facts for user."""
|
||||
embedding = await _embed(req.query)
|
||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
@@ -276,7 +319,7 @@ async def query_memories(req: QueryRequest):
|
||||
|
||||
|
||||
@app.delete("/memories/{user_id}")
|
||||
async def delete_user_memories(user_id: str):
|
||||
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||
"""GDPR delete — remove all memories for a user."""
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
@@ -287,7 +330,7 @@ async def delete_user_memories(user_id: str):
|
||||
|
||||
|
||||
@app.get("/memories/{user_id}")
|
||||
async def list_user_memories(user_id: str):
|
||||
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||
"""List all memories for a user (for UI/debug)."""
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
@@ -314,7 +357,7 @@ async def list_user_memories(user_id: str):
|
||||
|
||||
|
||||
@app.post("/chunks/store")
|
||||
async def store_chunk(req: ChunkStoreRequest):
|
||||
async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
|
||||
"""Store a conversation chunk with its summary embedding, encrypted."""
|
||||
if not req.summary.strip():
|
||||
raise HTTPException(400, "Empty summary")
|
||||
@@ -342,34 +385,26 @@ async def store_chunk(req: ChunkStoreRequest):
|
||||
|
||||
|
||||
@app.post("/chunks/query")
|
||||
async def query_chunks(req: ChunkQueryRequest):
|
||||
async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
|
||||
"""Semantic search over conversation chunks. Filter by user_id and/or room_id."""
|
||||
embedding = await _embed(req.query)
|
||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
|
||||
# Build WHERE clause dynamically based on provided filters
|
||||
conditions = []
|
||||
params: list = [vec_literal]
|
||||
idx = 2
|
||||
# Build WHERE clause — user_id is always required
|
||||
conditions = [f"user_id = $2"]
|
||||
params: list = [vec_literal, req.user_id]
|
||||
idx = 3
|
||||
|
||||
if req.user_id:
|
||||
conditions.append(f"user_id = ${idx}")
|
||||
params.append(req.user_id)
|
||||
idx += 1
|
||||
if req.room_id:
|
||||
conditions.append(f"room_id = ${idx}")
|
||||
params.append(req.room_id)
|
||||
idx += 1
|
||||
|
||||
where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
where = f"WHERE {' AND '.join(conditions)}"
|
||||
params.append(req.top_k)
|
||||
|
||||
# Determine user_id for RLS and decryption
|
||||
rls_user = req.user_id if req.user_id else ""
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
if rls_user:
|
||||
await _set_rls_user(conn, rls_user)
|
||||
await _set_rls_user(conn, req.user_id)
|
||||
|
||||
rows = await conn.fetch(
|
||||
f"""
|
||||
@@ -399,7 +434,7 @@ async def query_chunks(req: ChunkQueryRequest):
|
||||
|
||||
|
||||
@app.post("/chunks/bulk-store")
|
||||
async def bulk_store_chunks(req: ChunkBulkStoreRequest):
|
||||
async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
|
||||
"""Batch store conversation chunks. Embeds summaries in batches of 20."""
|
||||
if not req.chunks:
|
||||
return {"stored": 0}
|
||||
@@ -440,7 +475,7 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest):
|
||||
|
||||
|
||||
@app.get("/chunks/{user_id}/count")
|
||||
async def count_user_chunks(user_id: str):
|
||||
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
|
||||
"""Count conversation chunks for a user."""
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
|
||||
Reference in New Issue
Block a user