From 36c7e36456178d74e9cd09a52e50e2a544de0b01 Mon Sep 17 00:00:00 2001
From: Christian Gick
Date: Sun, 8 Mar 2026 13:45:15 +0200
Subject: [PATCH] security: enforce per-user data isolation in memory service

- Make user_id required on all request models with field validators
- Always include user_id in WHERE clause for chunk queries (prevents cross-user data leak)
- Add bearer token auth on all endpoints except /health
- Add composite index on (user_id, room_id) for conversation_chunks
- Bot: guard query_chunks with sender check, pass room_id, send auth token
- Docker: pass MEMORY_SERVICE_TOKEN to both bot and memory-service

Co-Authored-By: Claude Opus 4.6
---
Review notes (everything between the "---" line and the first "diff --git"
line is ignored by `git am`):

- verify_token: the bearer token is now compared as bytes.
  secrets.compare_digest() raises TypeError when given str arguments that
  contain non-ASCII characters, so a malformed Authorization header would
  have produced a 500 instead of a 403.
- query_chunks (memory-service): dropped the f-prefix on "user_id = $2";
  the string has no placeholders.
- NOTE(review): verify_token still skips authentication when
  MEMORY_SERVICE_TOKEN is empty, and docker-compose.yml defaults it to
  empty. Make sure the token is set in every deployed environment.
- NOTE(review): line breaks and indentation in this file were reconstructed
  from the hunk line counts and the language syntax. Verify context
  whitespace against the tree, or apply with
  `git apply --ignore-whitespace`. The post-image blob id on the
  memory-service/main.py "index" line predates the two edits above.

 bot.py                 | 36 +++++++++++++++---
 docker-compose.yml     |  2 +
 memory-service/main.py | 85 +++++++++++++++++++++++++++++-------------
 3 files changed, 92 insertions(+), 31 deletions(-)

diff --git a/bot.py b/bot.py
index 1c58b69..4049fef 100644
--- a/bot.py
+++ b/bot.py
@@ -70,6 +70,7 @@ LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
 LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
 DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
 MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
+MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
 CONFLUENCE_URL = os.environ.get("CONFLUENCE_BASE_URL", "")
 CONFLUENCE_USER = os.environ.get("CONFLUENCE_USER", "")
 CONFLUENCE_TOKEN = os.environ.get("CONFLUENCE_TOKEN", "")
@@ -444,10 +445,17 @@ class DocumentRAG:
 class MemoryClient:
     """Async HTTP client for the memory-service."""
 
-    def __init__(self, base_url: str):
+    def __init__(self, base_url: str, token: str = ""):
         self.base_url = base_url.rstrip("/")
+        self.token = token
         self.enabled = bool(base_url)
 
+    def _headers(self) -> dict:
+        h = {}
+        if self.token:
+            h["Authorization"] = f"Bearer {self.token}"
+        return h
+
     async def store(self, user_id: str, fact: str, source_room: str = ""):
         if not self.enabled:
             return
@@ -456,6 +464,7 @@ class MemoryClient:
                 await client.post(
                     f"{self.base_url}/memories/store",
                     json={"user_id": user_id, "fact": fact, "source_room": source_room},
+                    headers=self._headers(),
                 )
         except Exception:
             logger.warning("Memory store failed", exc_info=True)
@@ -468,6 +477,7 @@ class MemoryClient:
                 resp = await client.post(
                     f"{self.base_url}/memories/query",
                     json={"user_id": user_id, "query": query, "top_k": top_k},
+                    headers=self._headers(),
                 )
                 resp.raise_for_status()
                 return resp.json().get("results", [])
@@ -480,7 +490,10 @@ class MemoryClient:
             return 0
         try:
             async with httpx.AsyncClient(timeout=10.0) as client:
-                resp = await client.delete(f"{self.base_url}/memories/{user_id}")
+                resp = await client.delete(
+                    f"{self.base_url}/memories/{user_id}",
+                    headers=self._headers(),
+                )
                 resp.raise_for_status()
                 return resp.json().get("deleted", 0)
         except Exception:
@@ -492,7 +505,10 @@ class MemoryClient:
             return []
         try:
             async with httpx.AsyncClient(timeout=10.0) as client:
-                resp = await client.get(f"{self.base_url}/memories/{user_id}")
+                resp = await client.get(
+                    f"{self.base_url}/memories/{user_id}",
+                    headers=self._headers(),
+                )
                 resp.raise_for_status()
                 return resp.json().get("memories", [])
         except Exception:
@@ -512,19 +528,24 @@ class MemoryClient:
                         "chunk_text": chunk_text, "summary": summary,
                         "source_event_id": source_event_id, "original_ts": original_ts,
                     },
+                    headers=self._headers(),
                 )
         except Exception:
             logger.warning("Chunk store failed", exc_info=True)
 
-    async def query_chunks(self, query: str, user_id: str = "", room_id: str = "",
+    async def query_chunks(self, query: str, user_id: str, room_id: str = "",
                            top_k: int = 5) -> list[dict]:
         if not self.enabled:
             return []
+        if not user_id:
+            logger.error("query_chunks called with empty user_id — returning empty")
+            return []
         try:
             async with httpx.AsyncClient(timeout=10.0) as client:
                 resp = await client.post(
                     f"{self.base_url}/chunks/query",
                     json={"user_id": user_id, "room_id": room_id, "query": query, "top_k": top_k},
+                    headers=self._headers(),
                 )
                 resp.raise_for_status()
                 return resp.json().get("results", [])
@@ -961,7 +982,7 @@ class Bot:
         self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY,
                                rag_endpoint=RAG_ENDPOINT, rag_auth_token=RAG_AUTH_TOKEN)
         self.key_manager = RAGKeyManager(self.client, PORTAL_URL, BOT_API_KEY)
-        self.memory = MemoryClient(MEMORY_SERVICE_URL)
+        self.memory = MemoryClient(MEMORY_SERVICE_URL, token=MEMORY_SERVICE_TOKEN)
         self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY)
         self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
         self._documents_cache: dict[str, str | None] = {}  # matrix_user_id -> connected status
@@ -2092,7 +2113,10 @@ class Bot:
         memory_context = self._format_memories(memories)
 
         # Query relevant conversation chunks (RAG over chat history)
-        chunks = await self.memory.query_chunks(search_query, user_id=sender or "", top_k=5)
+        if sender:
+            chunks = await self.memory.query_chunks(search_query, user_id=sender, room_id=room.room_id, top_k=5)
+        else:
+            chunks = []
         chunk_context = self._format_chunks(chunks)
 
         # Include room document context (PDFs, Confluence pages, images uploaded to room)
diff --git a/docker-compose.yml b/docker-compose.yml
index a276a66..218f97c 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -20,6 +20,7 @@ services:
       - LITELLM_API_KEY
       - DEFAULT_MODEL
       - MEMORY_SERVICE_URL=http://memory-service:8090
+      - MEMORY_SERVICE_TOKEN
       - PORTAL_URL
       - BOT_API_KEY
     volumes:
@@ -60,6 +61,7 @@ services:
       LITELLM_BASE_URL: ${LITELLM_BASE_URL}
       LITELLM_API_KEY: ${LITELLM_MASTER_KEY}
       EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
+      MEMORY_SERVICE_TOKEN: ${MEMORY_SERVICE_TOKEN:-}
     depends_on:
       memory-db:
         condition: service_healthy
diff --git a/memory-service/main.py b/memory-service/main.py
index 8dde4ec..2ea0c46 100644
--- a/memory-service/main.py
+++ b/memory-service/main.py
@@ -1,5 +1,6 @@
 import os
 import logging
+import secrets
 import time
 import hashlib
 import base64
@@ -7,8 +8,8 @@ import base64
 import asyncpg
 import httpx
 from cryptography.fernet import Fernet
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
+from fastapi import Depends, FastAPI, Header, HTTPException
+from pydantic import BaseModel, field_validator
 
 logger = logging.getLogger("memory-service")
 logging.basicConfig(level=logging.INFO)
@@ -23,12 +24,23 @@ EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
 EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
 DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
 ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
+MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
 
 app = FastAPI(title="Memory Service")
 pool: asyncpg.Pool | None = None
 owner_pool: asyncpg.Pool | None = None
 
 
+async def verify_token(authorization: str | None = Header(None)):
+    """Bearer token auth — skipped if MEMORY_SERVICE_TOKEN not configured (dev mode)."""
+    if not MEMORY_SERVICE_TOKEN:
+        return
+    if not authorization or not authorization.startswith("Bearer "):
+        raise HTTPException(401, "Missing or invalid Authorization header")
+    if not secrets.compare_digest(authorization[7:].encode(), MEMORY_SERVICE_TOKEN.encode()):
+        raise HTTPException(403, "Invalid token")
+
+
 def _derive_user_key(user_id: str) -> bytes:
     """Derive a per-user Fernet key from master key + user_id via HMAC-SHA256."""
     if not ENCRYPTION_KEY:
@@ -64,12 +76,26 @@ class StoreRequest(BaseModel):
     fact: str
     source_room: str = ""
 
+    @field_validator('user_id')
+    @classmethod
+    def user_id_not_empty(cls, v):
+        if not v or not v.strip():
+            raise ValueError("user_id is required")
+        return v.strip()
+
 
 class QueryRequest(BaseModel):
     user_id: str
     query: str
     top_k: int = 10
 
+    @field_validator('user_id')
+    @classmethod
+    def user_id_not_empty(cls, v):
+        if not v or not v.strip():
+            raise ValueError("user_id is required")
+        return v.strip()
+
 
 class ChunkStoreRequest(BaseModel):
     user_id: str
@@ -79,13 +105,27 @@ class ChunkStoreRequest(BaseModel):
     source_event_id: str = ""
     original_ts: float = 0.0
 
+    @field_validator('user_id')
+    @classmethod
+    def user_id_not_empty(cls, v):
+        if not v or not v.strip():
+            raise ValueError("user_id is required")
+        return v.strip()
+
 
 class ChunkQueryRequest(BaseModel):
-    user_id: str = ""
+    user_id: str  # REQUIRED — no default
     room_id: str = ""
     query: str
     top_k: int = 5
 
+    @field_validator('user_id')
+    @classmethod
+    def user_id_not_empty(cls, v):
+        if not v or not v.strip():
+            raise ValueError("user_id is required")
+        return v.strip()
+
 
 class ChunkBulkStoreRequest(BaseModel):
     chunks: list[ChunkStoreRequest]
@@ -166,6 +206,9 @@ async def _init_db():
         await conn.execute("""
             CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
         """)
+        await conn.execute("""
+            CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
+        """)
     finally:
         await owner_conn.close()
     # Create restricted pool for all request handlers (RLS applies)
@@ -205,7 +248,7 @@ async def health():
 
 
 @app.post("/memories/store")
-async def store_memory(req: StoreRequest):
+async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
     """Embed fact, deduplicate by cosine similarity, insert encrypted."""
     if not req.fact.strip():
         raise HTTPException(400, "Empty fact")
@@ -243,7 +286,7 @@ async def store_memory(req: StoreRequest):
 
 
 @app.post("/memories/query")
-async def query_memories(req: QueryRequest):
+async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
     """Embed query, return top-K similar facts for user."""
     embedding = await _embed(req.query)
     vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
@@ -276,7 +319,7 @@ async def query_memories(req: QueryRequest):
 
 
 @app.delete("/memories/{user_id}")
-async def delete_user_memories(user_id: str):
+async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
     """GDPR delete — remove all memories for a user."""
     async with pool.acquire() as conn:
         await _set_rls_user(conn, user_id)
@@ -287,7 +330,7 @@ async def delete_user_memories(user_id: str):
 
 
 @app.get("/memories/{user_id}")
-async def list_user_memories(user_id: str):
+async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
     """List all memories for a user (for UI/debug)."""
     async with pool.acquire() as conn:
         await _set_rls_user(conn, user_id)
@@ -314,7 +357,7 @@ async def list_user_memories(user_id: str):
 
 
 @app.post("/chunks/store")
-async def store_chunk(req: ChunkStoreRequest):
+async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
     """Store a conversation chunk with its summary embedding, encrypted."""
     if not req.summary.strip():
         raise HTTPException(400, "Empty summary")
@@ -342,34 +385,26 @@ async def store_chunk(req: ChunkStoreRequest):
 
 
 @app.post("/chunks/query")
-async def query_chunks(req: ChunkQueryRequest):
+async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
     """Semantic search over conversation chunks. Filter by user_id and/or room_id."""
     embedding = await _embed(req.query)
     vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
 
-    # Build WHERE clause dynamically based on provided filters
-    conditions = []
-    params: list = [vec_literal]
-    idx = 2
+    # Build WHERE clause — user_id is always required
+    conditions = ["user_id = $2"]
+    params: list = [vec_literal, req.user_id]
+    idx = 3
 
-    if req.user_id:
-        conditions.append(f"user_id = ${idx}")
-        params.append(req.user_id)
-        idx += 1
     if req.room_id:
         conditions.append(f"room_id = ${idx}")
         params.append(req.room_id)
         idx += 1
 
-    where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
+    where = f"WHERE {' AND '.join(conditions)}"
     params.append(req.top_k)
 
-    # Determine user_id for RLS and decryption
-    rls_user = req.user_id if req.user_id else ""
-
     async with pool.acquire() as conn:
-        if rls_user:
-            await _set_rls_user(conn, rls_user)
+        await _set_rls_user(conn, req.user_id)
 
         rows = await conn.fetch(
             f"""
@@ -399,7 +434,7 @@ async def query_chunks(req: ChunkQueryRequest):
 
 
 @app.post("/chunks/bulk-store")
-async def bulk_store_chunks(req: ChunkBulkStoreRequest):
+async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
     """Batch store conversation chunks. Embeds summaries in batches of 20."""
     if not req.chunks:
         return {"stored": 0}
@@ -440,7 +475,7 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest):
 
 
 @app.get("/chunks/{user_id}/count")
-async def count_user_chunks(user_id: str):
+async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
     """Count conversation chunks for a user."""
     async with pool.acquire() as conn:
         await _set_rls_user(conn, user_id)