security: enforce per-user data isolation in memory service

- Make user_id required on all request models with field validators
- Always include user_id in WHERE clause for chunk queries (prevents cross-user data leak)
- Add bearer token auth on all endpoints except /health
- Add composite index on (user_id, room_id) for conversation_chunks
- Bot: guard query_chunks with sender check, pass room_id, send auth token
- Docker: pass MEMORY_SERVICE_TOKEN to both bot and memory-service

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Christian Gick
2026-03-08 13:45:15 +02:00
parent e584ce8ce0
commit 36c7e36456
3 changed files with 92 additions and 31 deletions

36
bot.py
View File

@@ -70,6 +70,7 @@ LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed") LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet") DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090") MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
CONFLUENCE_URL = os.environ.get("CONFLUENCE_BASE_URL", "") CONFLUENCE_URL = os.environ.get("CONFLUENCE_BASE_URL", "")
CONFLUENCE_USER = os.environ.get("CONFLUENCE_USER", "") CONFLUENCE_USER = os.environ.get("CONFLUENCE_USER", "")
CONFLUENCE_TOKEN = os.environ.get("CONFLUENCE_TOKEN", "") CONFLUENCE_TOKEN = os.environ.get("CONFLUENCE_TOKEN", "")
@@ -444,10 +445,17 @@ class DocumentRAG:
class MemoryClient: class MemoryClient:
"""Async HTTP client for the memory-service.""" """Async HTTP client for the memory-service."""
def __init__(self, base_url: str): def __init__(self, base_url: str, token: str = ""):
self.base_url = base_url.rstrip("/") self.base_url = base_url.rstrip("/")
self.token = token
self.enabled = bool(base_url) self.enabled = bool(base_url)
def _headers(self) -> dict:
    """Build request headers for the memory-service.

    Returns a dict carrying the Bearer ``Authorization`` header when a
    token was configured, otherwise an empty dict (dev mode, no auth).
    """
    if self.token:
        return {"Authorization": f"Bearer {self.token}"}
    return {}
async def store(self, user_id: str, fact: str, source_room: str = ""): async def store(self, user_id: str, fact: str, source_room: str = ""):
if not self.enabled: if not self.enabled:
return return
@@ -456,6 +464,7 @@ class MemoryClient:
await client.post( await client.post(
f"{self.base_url}/memories/store", f"{self.base_url}/memories/store",
json={"user_id": user_id, "fact": fact, "source_room": source_room}, json={"user_id": user_id, "fact": fact, "source_room": source_room},
headers=self._headers(),
) )
except Exception: except Exception:
logger.warning("Memory store failed", exc_info=True) logger.warning("Memory store failed", exc_info=True)
@@ -468,6 +477,7 @@ class MemoryClient:
resp = await client.post( resp = await client.post(
f"{self.base_url}/memories/query", f"{self.base_url}/memories/query",
json={"user_id": user_id, "query": query, "top_k": top_k}, json={"user_id": user_id, "query": query, "top_k": top_k},
headers=self._headers(),
) )
resp.raise_for_status() resp.raise_for_status()
return resp.json().get("results", []) return resp.json().get("results", [])
@@ -480,7 +490,10 @@ class MemoryClient:
return 0 return 0
try: try:
async with httpx.AsyncClient(timeout=10.0) as client: async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.delete(f"{self.base_url}/memories/{user_id}") resp = await client.delete(
f"{self.base_url}/memories/{user_id}",
headers=self._headers(),
)
resp.raise_for_status() resp.raise_for_status()
return resp.json().get("deleted", 0) return resp.json().get("deleted", 0)
except Exception: except Exception:
@@ -492,7 +505,10 @@ class MemoryClient:
return [] return []
try: try:
async with httpx.AsyncClient(timeout=10.0) as client: async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(f"{self.base_url}/memories/{user_id}") resp = await client.get(
f"{self.base_url}/memories/{user_id}",
headers=self._headers(),
)
resp.raise_for_status() resp.raise_for_status()
return resp.json().get("memories", []) return resp.json().get("memories", [])
except Exception: except Exception:
@@ -512,19 +528,24 @@ class MemoryClient:
"chunk_text": chunk_text, "summary": summary, "chunk_text": chunk_text, "summary": summary,
"source_event_id": source_event_id, "original_ts": original_ts, "source_event_id": source_event_id, "original_ts": original_ts,
}, },
headers=self._headers(),
) )
except Exception: except Exception:
logger.warning("Chunk store failed", exc_info=True) logger.warning("Chunk store failed", exc_info=True)
async def query_chunks(self, query: str, user_id: str = "", room_id: str = "", async def query_chunks(self, query: str, user_id: str, room_id: str = "",
top_k: int = 5) -> list[dict]: top_k: int = 5) -> list[dict]:
if not self.enabled: if not self.enabled:
return [] return []
if not user_id:
logger.error("query_chunks called with empty user_id — returning empty")
return []
try: try:
async with httpx.AsyncClient(timeout=10.0) as client: async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.post( resp = await client.post(
f"{self.base_url}/chunks/query", f"{self.base_url}/chunks/query",
json={"user_id": user_id, "room_id": room_id, "query": query, "top_k": top_k}, json={"user_id": user_id, "room_id": room_id, "query": query, "top_k": top_k},
headers=self._headers(),
) )
resp.raise_for_status() resp.raise_for_status()
return resp.json().get("results", []) return resp.json().get("results", [])
@@ -961,7 +982,7 @@ class Bot:
self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY, self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY,
rag_endpoint=RAG_ENDPOINT, rag_auth_token=RAG_AUTH_TOKEN) rag_endpoint=RAG_ENDPOINT, rag_auth_token=RAG_AUTH_TOKEN)
self.key_manager = RAGKeyManager(self.client, PORTAL_URL, BOT_API_KEY) self.key_manager = RAGKeyManager(self.client, PORTAL_URL, BOT_API_KEY)
self.memory = MemoryClient(MEMORY_SERVICE_URL) self.memory = MemoryClient(MEMORY_SERVICE_URL, token=MEMORY_SERVICE_TOKEN)
self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY) self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY)
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
self._documents_cache: dict[str, str | None] = {} # matrix_user_id -> connected status self._documents_cache: dict[str, str | None] = {} # matrix_user_id -> connected status
@@ -2092,7 +2113,10 @@ class Bot:
memory_context = self._format_memories(memories) memory_context = self._format_memories(memories)
# Query relevant conversation chunks (RAG over chat history) # Query relevant conversation chunks (RAG over chat history)
chunks = await self.memory.query_chunks(search_query, user_id=sender or "", top_k=5) if sender:
chunks = await self.memory.query_chunks(search_query, user_id=sender, room_id=room.room_id, top_k=5)
else:
chunks = []
chunk_context = self._format_chunks(chunks) chunk_context = self._format_chunks(chunks)
# Include room document context (PDFs, Confluence pages, images uploaded to room) # Include room document context (PDFs, Confluence pages, images uploaded to room)

View File

@@ -20,6 +20,7 @@ services:
- LITELLM_API_KEY - LITELLM_API_KEY
- DEFAULT_MODEL - DEFAULT_MODEL
- MEMORY_SERVICE_URL=http://memory-service:8090 - MEMORY_SERVICE_URL=http://memory-service:8090
- MEMORY_SERVICE_TOKEN
- PORTAL_URL - PORTAL_URL
- BOT_API_KEY - BOT_API_KEY
volumes: volumes:
@@ -60,6 +61,7 @@ services:
LITELLM_BASE_URL: ${LITELLM_BASE_URL} LITELLM_BASE_URL: ${LITELLM_BASE_URL}
LITELLM_API_KEY: ${LITELLM_MASTER_KEY} LITELLM_API_KEY: ${LITELLM_MASTER_KEY}
EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small} EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
MEMORY_SERVICE_TOKEN: ${MEMORY_SERVICE_TOKEN:-}
depends_on: depends_on:
memory-db: memory-db:
condition: service_healthy condition: service_healthy

View File

@@ -1,5 +1,6 @@
import os import os
import logging import logging
import secrets
import time import time
import hashlib import hashlib
import base64 import base64
@@ -7,8 +8,8 @@ import base64
import asyncpg import asyncpg
import httpx import httpx
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from fastapi import FastAPI, HTTPException from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel from pydantic import BaseModel, field_validator
logger = logging.getLogger("memory-service") logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@@ -23,12 +24,23 @@ EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536")) EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92")) DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "") ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
app = FastAPI(title="Memory Service") app = FastAPI(title="Memory Service")
pool: asyncpg.Pool | None = None pool: asyncpg.Pool | None = None
owner_pool: asyncpg.Pool | None = None owner_pool: asyncpg.Pool | None = None
async def verify_token(authorization: str | None = Header(None)):
    """FastAPI dependency enforcing Bearer-token auth on memory endpoints.

    No-op when MEMORY_SERVICE_TOKEN is unset (dev mode). Raises 401 for a
    missing/malformed header and 403 for a wrong token; comparison uses
    secrets.compare_digest to avoid timing side channels.
    """
    expected = MEMORY_SERVICE_TOKEN
    if not expected:
        return
    prefix = "Bearer "
    if not (authorization and authorization.startswith(prefix)):
        raise HTTPException(401, "Missing or invalid Authorization header")
    supplied = authorization[len(prefix):]
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(403, "Invalid token")
def _derive_user_key(user_id: str) -> bytes: def _derive_user_key(user_id: str) -> bytes:
"""Derive a per-user Fernet key from master key + user_id via HMAC-SHA256.""" """Derive a per-user Fernet key from master key + user_id via HMAC-SHA256."""
if not ENCRYPTION_KEY: if not ENCRYPTION_KEY:
@@ -64,12 +76,26 @@ class StoreRequest(BaseModel):
fact: str fact: str
source_room: str = "" source_room: str = ""
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
user_id: str user_id: str
query: str query: str
top_k: int = 10 top_k: int = 10
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class ChunkStoreRequest(BaseModel): class ChunkStoreRequest(BaseModel):
user_id: str user_id: str
@@ -79,13 +105,27 @@ class ChunkStoreRequest(BaseModel):
source_event_id: str = "" source_event_id: str = ""
original_ts: float = 0.0 original_ts: float = 0.0
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class ChunkQueryRequest(BaseModel): class ChunkQueryRequest(BaseModel):
user_id: str = "" user_id: str # REQUIRED — no default
room_id: str = "" room_id: str = ""
query: str query: str
top_k: int = 5 top_k: int = 5
@field_validator('user_id')
@classmethod
def user_id_not_empty(cls, v):
if not v or not v.strip():
raise ValueError("user_id is required")
return v.strip()
class ChunkBulkStoreRequest(BaseModel): class ChunkBulkStoreRequest(BaseModel):
chunks: list[ChunkStoreRequest] chunks: list[ChunkStoreRequest]
@@ -166,6 +206,9 @@ async def _init_db():
await conn.execute(""" await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id) CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
""") """)
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
""")
finally: finally:
await owner_conn.close() await owner_conn.close()
# Create restricted pool for all request handlers (RLS applies) # Create restricted pool for all request handlers (RLS applies)
@@ -205,7 +248,7 @@ async def health():
@app.post("/memories/store") @app.post("/memories/store")
async def store_memory(req: StoreRequest): async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
"""Embed fact, deduplicate by cosine similarity, insert encrypted.""" """Embed fact, deduplicate by cosine similarity, insert encrypted."""
if not req.fact.strip(): if not req.fact.strip():
raise HTTPException(400, "Empty fact") raise HTTPException(400, "Empty fact")
@@ -243,7 +286,7 @@ async def store_memory(req: StoreRequest):
@app.post("/memories/query") @app.post("/memories/query")
async def query_memories(req: QueryRequest): async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
"""Embed query, return top-K similar facts for user.""" """Embed query, return top-K similar facts for user."""
embedding = await _embed(req.query) embedding = await _embed(req.query)
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]" vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
@@ -276,7 +319,7 @@ async def query_memories(req: QueryRequest):
@app.delete("/memories/{user_id}") @app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str): async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
"""GDPR delete — remove all memories for a user.""" """GDPR delete — remove all memories for a user."""
async with pool.acquire() as conn: async with pool.acquire() as conn:
await _set_rls_user(conn, user_id) await _set_rls_user(conn, user_id)
@@ -287,7 +330,7 @@ async def delete_user_memories(user_id: str):
@app.get("/memories/{user_id}") @app.get("/memories/{user_id}")
async def list_user_memories(user_id: str): async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
"""List all memories for a user (for UI/debug).""" """List all memories for a user (for UI/debug)."""
async with pool.acquire() as conn: async with pool.acquire() as conn:
await _set_rls_user(conn, user_id) await _set_rls_user(conn, user_id)
@@ -314,7 +357,7 @@ async def list_user_memories(user_id: str):
@app.post("/chunks/store") @app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest): async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
"""Store a conversation chunk with its summary embedding, encrypted.""" """Store a conversation chunk with its summary embedding, encrypted."""
if not req.summary.strip(): if not req.summary.strip():
raise HTTPException(400, "Empty summary") raise HTTPException(400, "Empty summary")
@@ -342,34 +385,26 @@ async def store_chunk(req: ChunkStoreRequest):
@app.post("/chunks/query") @app.post("/chunks/query")
async def query_chunks(req: ChunkQueryRequest): async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
"""Semantic search over conversation chunks. Filter by user_id and/or room_id.""" """Semantic search over conversation chunks. Filter by user_id and/or room_id."""
embedding = await _embed(req.query) embedding = await _embed(req.query)
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]" vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
# Build WHERE clause dynamically based on provided filters # Build WHERE clause — user_id is always required
conditions = [] conditions = [f"user_id = $2"]
params: list = [vec_literal] params: list = [vec_literal, req.user_id]
idx = 2 idx = 3
if req.user_id:
conditions.append(f"user_id = ${idx}")
params.append(req.user_id)
idx += 1
if req.room_id: if req.room_id:
conditions.append(f"room_id = ${idx}") conditions.append(f"room_id = ${idx}")
params.append(req.room_id) params.append(req.room_id)
idx += 1 idx += 1
where = f"WHERE {' AND '.join(conditions)}" if conditions else "" where = f"WHERE {' AND '.join(conditions)}"
params.append(req.top_k) params.append(req.top_k)
# Determine user_id for RLS and decryption
rls_user = req.user_id if req.user_id else ""
async with pool.acquire() as conn: async with pool.acquire() as conn:
if rls_user: await _set_rls_user(conn, req.user_id)
await _set_rls_user(conn, rls_user)
rows = await conn.fetch( rows = await conn.fetch(
f""" f"""
@@ -399,7 +434,7 @@ async def query_chunks(req: ChunkQueryRequest):
@app.post("/chunks/bulk-store") @app.post("/chunks/bulk-store")
async def bulk_store_chunks(req: ChunkBulkStoreRequest): async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
"""Batch store conversation chunks. Embeds summaries in batches of 20.""" """Batch store conversation chunks. Embeds summaries in batches of 20."""
if not req.chunks: if not req.chunks:
return {"stored": 0} return {"stored": 0}
@@ -440,7 +475,7 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest):
@app.get("/chunks/{user_id}/count") @app.get("/chunks/{user_id}/count")
async def count_user_chunks(user_id: str): async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
"""Count conversation chunks for a user.""" """Count conversation chunks for a user."""
async with pool.acquire() as conn: async with pool.acquire() as conn:
await _set_rls_user(conn, user_id) await _set_rls_user(conn, user_id)