security: enforce per-user data isolation in memory service

- Make user_id required on all request models with field validators
- Always include user_id in the WHERE clause for chunk queries (prevents cross-user data leaks)
- Add bearer token auth on all endpoints except /health (request sketch after this list)
- Add composite index on (user_id, room_id) for conversation_chunks
- Bot: guard query_chunks with sender check, pass room_id, send auth token
- Docker: pass MEMORY_SERVICE_TOKEN to both bot and memory-service
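
For illustration, a well-formed chunk query after this change looks like the sketch below (the IDs, token value, and query text are placeholders; the endpoint, port, and field names come from the memory service):

    import asyncio
    import httpx

    async def demo():
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                "http://memory-service:8090/chunks/query",
                json={"user_id": "@alice:example.org", "room_id": "!ops:example.org",
                      "query": "postgres backup schedule", "top_k": 5},
                headers={"Authorization": "Bearer <MEMORY_SERVICE_TOKEN>"},
            )
            resp.raise_for_status()  # 401/403 without a valid token, 422 if user_id is empty
            print(resp.json().get("results", []))

    asyncio.run(demo())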

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Christian Gick
2026-03-08 13:45:15 +02:00
parent e584ce8ce0
commit 36c7e36456
3 changed files with 92 additions and 31 deletions

bot.py

@@ -70,6 +70,7 @@ LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
CONFLUENCE_URL = os.environ.get("CONFLUENCE_BASE_URL", "")
CONFLUENCE_USER = os.environ.get("CONFLUENCE_USER", "")
CONFLUENCE_TOKEN = os.environ.get("CONFLUENCE_TOKEN", "")
@@ -444,10 +445,17 @@ class DocumentRAG:
class MemoryClient:
"""Async HTTP client for the memory-service."""
def __init__(self, base_url: str):
def __init__(self, base_url: str, token: str = ""):
self.base_url = base_url.rstrip("/")
self.token = token
self.enabled = bool(base_url)
def _headers(self) -> dict:
h = {}
if self.token:
h["Authorization"] = f"Bearer {self.token}"
return h
async def store(self, user_id: str, fact: str, source_room: str = ""):
if not self.enabled:
return
@@ -456,6 +464,7 @@ class MemoryClient:
await client.post(
f"{self.base_url}/memories/store",
json={"user_id": user_id, "fact": fact, "source_room": source_room},
headers=self._headers(),
)
except Exception:
logger.warning("Memory store failed", exc_info=True)
@@ -468,6 +477,7 @@ class MemoryClient:
resp = await client.post(
f"{self.base_url}/memories/query",
json={"user_id": user_id, "query": query, "top_k": top_k},
headers=self._headers(),
)
resp.raise_for_status()
return resp.json().get("results", [])
@@ -480,7 +490,10 @@ class MemoryClient:
return 0
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.delete(f"{self.base_url}/memories/{user_id}")
resp = await client.delete(
f"{self.base_url}/memories/{user_id}",
headers=self._headers(),
)
resp.raise_for_status()
return resp.json().get("deleted", 0)
except Exception:
@@ -492,7 +505,10 @@ class MemoryClient:
return []
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(f"{self.base_url}/memories/{user_id}")
resp = await client.get(
f"{self.base_url}/memories/{user_id}",
headers=self._headers(),
)
resp.raise_for_status()
return resp.json().get("memories", [])
except Exception:
@@ -512,19 +528,24 @@ class MemoryClient:
"chunk_text": chunk_text, "summary": summary,
"source_event_id": source_event_id, "original_ts": original_ts,
},
headers=self._headers(),
)
except Exception:
logger.warning("Chunk store failed", exc_info=True)
async def query_chunks(self, query: str, user_id: str = "", room_id: str = "",
async def query_chunks(self, query: str, user_id: str, room_id: str = "",
top_k: int = 5) -> list[dict]:
if not self.enabled:
return []
if not user_id:
logger.error("query_chunks called with empty user_id — returning empty")
return []
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.post(
f"{self.base_url}/chunks/query",
json={"user_id": user_id, "room_id": room_id, "query": query, "top_k": top_k},
headers=self._headers(),
)
resp.raise_for_status()
return resp.json().get("results", [])
@@ -961,7 +982,7 @@ class Bot:
self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY,
rag_endpoint=RAG_ENDPOINT, rag_auth_token=RAG_AUTH_TOKEN)
self.key_manager = RAGKeyManager(self.client, PORTAL_URL, BOT_API_KEY)
self.memory = MemoryClient(MEMORY_SERVICE_URL)
self.memory = MemoryClient(MEMORY_SERVICE_URL, token=MEMORY_SERVICE_TOKEN)
self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY)
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
self._documents_cache: dict[str, str | None] = {} # matrix_user_id -> connected status
@@ -2092,7 +2113,10 @@ class Bot:
memory_context = self._format_memories(memories)
# Query relevant conversation chunks (RAG over chat history)
chunks = await self.memory.query_chunks(search_query, user_id=sender or "", top_k=5)
if sender:
chunks = await self.memory.query_chunks(search_query, user_id=sender, room_id=room.room_id, top_k=5)
else:
chunks = []
chunk_context = self._format_chunks(chunks)
# Include room document context (PDFs, Confluence pages, images uploaded to room)
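
A quick standalone check of the client-side behavior above (a sketch only; it assumes bot.py imports cleanly without a running homeserver, and relies on the empty-user_id path returning before any HTTP request is made):

    import asyncio
    from bot import MemoryClient

    async def check():
        mc = MemoryClient("http://memory-service:8090", token="dummy-token")
        # The token is attached as a standard Bearer header on every call.
        assert mc._headers() == {"Authorization": "Bearer dummy-token"}
        # An empty user_id is refused locally: the client logs an error and returns [].
        assert await mc.query_chunks("anything", user_id="") == []

    asyncio.run(check())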


@@ -20,6 +20,7 @@ services:
- LITELLM_API_KEY
- DEFAULT_MODEL
- MEMORY_SERVICE_URL=http://memory-service:8090
- MEMORY_SERVICE_TOKEN
- PORTAL_URL
- BOT_API_KEY
volumes:
@@ -60,6 +61,7 @@ services:
LITELLM_BASE_URL: ${LITELLM_BASE_URL}
LITELLM_API_KEY: ${LITELLM_MASTER_KEY}
EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
MEMORY_SERVICE_TOKEN: ${MEMORY_SERVICE_TOKEN:-}
depends_on:
memory-db:
condition: service_healthy
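
Both containers read the same MEMORY_SERVICE_TOKEN, so the value only needs to be set once in the deployment environment (for example the .env file compose reads). One way to generate it, shown purely as an example since any long random string works:

    import secrets
    # Paste the output into the environment as MEMORY_SERVICE_TOKEN=<value>
    print(secrets.token_urlsafe(32))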


@@ -1,5 +1,6 @@
import os
import logging
import secrets
import time
import hashlib
import base64
@@ -7,8 +8,8 @@ import base64
import asyncpg
import httpx
from cryptography.fernet import Fernet
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, field_validator
logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)
@@ -23,12 +24,23 @@ EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
app = FastAPI(title="Memory Service")
pool: asyncpg.Pool | None = None
owner_pool: asyncpg.Pool | None = None
async def verify_token(authorization: str | None = Header(None)):
"""Bearer token auth — skipped if MEMORY_SERVICE_TOKEN not configured (dev mode)."""
if not MEMORY_SERVICE_TOKEN:
return
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(401, "Missing or invalid Authorization header")
if not secrets.compare_digest(authorization[7:], MEMORY_SERVICE_TOKEN):
raise HTTPException(403, "Invalid token")
def _derive_user_key(user_id: str) -> bytes:
"""Derive a per-user Fernet key from master key + user_id via HMAC-SHA256."""
if not ENCRYPTION_KEY:
@@ -64,12 +76,26 @@ class StoreRequest(BaseModel):
fact: str
source_room: str = ""
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class QueryRequest(BaseModel):
user_id: str
query: str
top_k: int = 10
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class ChunkStoreRequest(BaseModel):
user_id: str
@@ -79,13 +105,27 @@ class ChunkStoreRequest(BaseModel):
source_event_id: str = ""
original_ts: float = 0.0
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class ChunkQueryRequest(BaseModel):
user_id: str = ""
user_id: str # REQUIRED — no default
room_id: str = ""
query: str
top_k: int = 5
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class ChunkBulkStoreRequest(BaseModel):
chunks: list[ChunkStoreRequest]
@@ -166,6 +206,9 @@ async def _init_db():
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
""")
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
""")
finally:
await owner_conn.close()
# Create restricted pool for all request handlers (RLS applies)
@@ -205,7 +248,7 @@ async def health():
@app.post("/memories/store")
async def store_memory(req: StoreRequest):
async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
"""Embed fact, deduplicate by cosine similarity, insert encrypted."""
if not req.fact.strip():
raise HTTPException(400, "Empty fact")
@@ -243,7 +286,7 @@ async def store_memory(req: StoreRequest):
@app.post("/memories/query")
async def query_memories(req: QueryRequest):
async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
"""Embed query, return top-K similar facts for user."""
embedding = await _embed(req.query)
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
@@ -276,7 +319,7 @@ async def query_memories(req: QueryRequest):
@app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str):
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
"""GDPR delete — remove all memories for a user."""
async with pool.acquire() as conn:
await _set_rls_user(conn, user_id)
@@ -287,7 +330,7 @@ async def delete_user_memories(user_id: str):
@app.get("/memories/{user_id}")
async def list_user_memories(user_id: str):
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
"""List all memories for a user (for UI/debug)."""
async with pool.acquire() as conn:
await _set_rls_user(conn, user_id)
@@ -314,7 +357,7 @@ async def list_user_memories(user_id: str):
@app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest):
async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
"""Store a conversation chunk with its summary embedding, encrypted."""
if not req.summary.strip():
raise HTTPException(400, "Empty summary")
@@ -342,34 +385,26 @@ async def store_chunk(req: ChunkStoreRequest):
@app.post("/chunks/query")
async def query_chunks(req: ChunkQueryRequest):
async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
"""Semantic search over conversation chunks. Filter by user_id and/or room_id."""
embedding = await _embed(req.query)
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
# Build WHERE clause dynamically based on provided filters
conditions = []
params: list = [vec_literal]
idx = 2
# Build WHERE clause — user_id is always required
conditions = [f"user_id = $2"]
params: list = [vec_literal, req.user_id]
idx = 3
if req.user_id:
conditions.append(f"user_id = ${idx}")
params.append(req.user_id)
idx += 1
if req.room_id:
conditions.append(f"room_id = ${idx}")
params.append(req.room_id)
idx += 1
where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
where = f"WHERE {' AND '.join(conditions)}"
params.append(req.top_k)
# Determine user_id for RLS and decryption
rls_user = req.user_id if req.user_id else ""
async with pool.acquire() as conn:
if rls_user:
await _set_rls_user(conn, rls_user)
await _set_rls_user(conn, req.user_id)
rows = await conn.fetch(
f"""
@@ -399,7 +434,7 @@ async def query_chunks(req: ChunkQueryRequest):
@app.post("/chunks/bulk-store")
async def bulk_store_chunks(req: ChunkBulkStoreRequest):
async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
"""Batch store conversation chunks. Embeds summaries in batches of 20."""
if not req.chunks:
return {"stored": 0}
@@ -440,7 +475,7 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest):
@app.get("/chunks/{user_id}/count")
async def count_user_chunks(user_id: str):
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
"""Count conversation chunks for a user."""
async with pool.acquire() as conn:
await _set_rls_user(conn, user_id)
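
Taken together, the token dependency and the user_id validator give two distinct failure modes. The self-contained sketch below reproduces them without the real service or database; the app, token, and model are stand-ins that mirror the code above, not imports from it:

    import secrets
    from fastapi import Depends, FastAPI, Header, HTTPException
    from fastapi.testclient import TestClient
    from pydantic import BaseModel, field_validator

    TOKEN = "example-token"
    app = FastAPI()

    async def verify_token(authorization: str | None = Header(None)):
        # Mirrors the service: 401 for a missing/malformed header, 403 for a wrong token.
        if not authorization or not authorization.startswith("Bearer "):
            raise HTTPException(401, "Missing or invalid Authorization header")
        if not secrets.compare_digest(authorization[7:], TOKEN):
            raise HTTPException(403, "Invalid token")

    class ChunkQueryRequest(BaseModel):
        user_id: str  # required, no default
        room_id: str = ""
        query: str
        top_k: int = 5

        @field_validator("user_id")
        @classmethod
        def user_id_not_empty(cls, v):
            if not v or not v.strip():
                raise ValueError("user_id is required")
            return v.strip()

    @app.post("/chunks/query")
    async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
        return {"results": [], "user_id": req.user_id}

    client = TestClient(app)
    body = {"user_id": "@alice:example.org", "query": "weekend plans"}
    auth = {"Authorization": f"Bearer {TOKEN}"}

    assert client.post("/chunks/query", json=body).status_code == 401                   # no token
    assert client.post("/chunks/query", json=body, headers=auth).status_code == 200     # authorized
    assert client.post("/chunks/query", json=dict(body, user_id=""), headers=auth).status_code == 422  # empty user_id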