security: enforce per-user data isolation in memory service
- Make user_id required on all request models with field validators - Always include user_id in WHERE clause for chunk queries (prevents cross-user data leak) - Add bearer token auth on all endpoints except /health - Add composite index on (user_id, room_id) for conversation_chunks - Bot: guard query_chunks with sender check, pass room_id, send auth token - Docker: pass MEMORY_SERVICE_TOKEN to both bot and memory-service Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
36
bot.py
36
bot.py
@@ -70,6 +70,7 @@ LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
|
||||
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
|
||||
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
|
||||
MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
|
||||
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
|
||||
CONFLUENCE_URL = os.environ.get("CONFLUENCE_BASE_URL", "")
|
||||
CONFLUENCE_USER = os.environ.get("CONFLUENCE_USER", "")
|
||||
CONFLUENCE_TOKEN = os.environ.get("CONFLUENCE_TOKEN", "")
|
||||
@@ -444,10 +445,17 @@ class DocumentRAG:
|
||||
class MemoryClient:
|
||||
"""Async HTTP client for the memory-service."""
|
||||
|
||||
def __init__(self, base_url: str):
|
||||
def __init__(self, base_url: str, token: str = ""):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.token = token
|
||||
self.enabled = bool(base_url)
|
||||
|
||||
def _headers(self) -> dict:
|
||||
h = {}
|
||||
if self.token:
|
||||
h["Authorization"] = f"Bearer {self.token}"
|
||||
return h
|
||||
|
||||
async def store(self, user_id: str, fact: str, source_room: str = ""):
|
||||
if not self.enabled:
|
||||
return
|
||||
@@ -456,6 +464,7 @@ class MemoryClient:
|
||||
await client.post(
|
||||
f"{self.base_url}/memories/store",
|
||||
json={"user_id": user_id, "fact": fact, "source_room": source_room},
|
||||
headers=self._headers(),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Memory store failed", exc_info=True)
|
||||
@@ -468,6 +477,7 @@ class MemoryClient:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/memories/query",
|
||||
json={"user_id": user_id, "query": query, "top_k": top_k},
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("results", [])
|
||||
@@ -480,7 +490,10 @@ class MemoryClient:
|
||||
return 0
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.delete(f"{self.base_url}/memories/{user_id}")
|
||||
resp = await client.delete(
|
||||
f"{self.base_url}/memories/{user_id}",
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("deleted", 0)
|
||||
except Exception:
|
||||
@@ -492,7 +505,10 @@ class MemoryClient:
|
||||
return []
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(f"{self.base_url}/memories/{user_id}")
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/memories/{user_id}",
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("memories", [])
|
||||
except Exception:
|
||||
@@ -512,19 +528,24 @@ class MemoryClient:
|
||||
"chunk_text": chunk_text, "summary": summary,
|
||||
"source_event_id": source_event_id, "original_ts": original_ts,
|
||||
},
|
||||
headers=self._headers(),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Chunk store failed", exc_info=True)
|
||||
|
||||
async def query_chunks(self, query: str, user_id: str = "", room_id: str = "",
|
||||
async def query_chunks(self, query: str, user_id: str, room_id: str = "",
|
||||
top_k: int = 5) -> list[dict]:
|
||||
if not self.enabled:
|
||||
return []
|
||||
if not user_id:
|
||||
logger.error("query_chunks called with empty user_id — returning empty")
|
||||
return []
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/chunks/query",
|
||||
json={"user_id": user_id, "room_id": room_id, "query": query, "top_k": top_k},
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("results", [])
|
||||
@@ -961,7 +982,7 @@ class Bot:
|
||||
self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY,
|
||||
rag_endpoint=RAG_ENDPOINT, rag_auth_token=RAG_AUTH_TOKEN)
|
||||
self.key_manager = RAGKeyManager(self.client, PORTAL_URL, BOT_API_KEY)
|
||||
self.memory = MemoryClient(MEMORY_SERVICE_URL)
|
||||
self.memory = MemoryClient(MEMORY_SERVICE_URL, token=MEMORY_SERVICE_TOKEN)
|
||||
self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY)
|
||||
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
|
||||
self._documents_cache: dict[str, str | None] = {} # matrix_user_id -> connected status
|
||||
@@ -2092,7 +2113,10 @@ class Bot:
|
||||
memory_context = self._format_memories(memories)
|
||||
|
||||
# Query relevant conversation chunks (RAG over chat history)
|
||||
chunks = await self.memory.query_chunks(search_query, user_id=sender or "", top_k=5)
|
||||
if sender:
|
||||
chunks = await self.memory.query_chunks(search_query, user_id=sender, room_id=room.room_id, top_k=5)
|
||||
else:
|
||||
chunks = []
|
||||
chunk_context = self._format_chunks(chunks)
|
||||
|
||||
# Include room document context (PDFs, Confluence pages, images uploaded to room)
|
||||
|
||||
@@ -20,6 +20,7 @@ services:
|
||||
- LITELLM_API_KEY
|
||||
- DEFAULT_MODEL
|
||||
- MEMORY_SERVICE_URL=http://memory-service:8090
|
||||
- MEMORY_SERVICE_TOKEN
|
||||
- PORTAL_URL
|
||||
- BOT_API_KEY
|
||||
volumes:
|
||||
@@ -60,6 +61,7 @@ services:
|
||||
LITELLM_BASE_URL: ${LITELLM_BASE_URL}
|
||||
LITELLM_API_KEY: ${LITELLM_MASTER_KEY}
|
||||
EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
|
||||
MEMORY_SERVICE_TOKEN: ${MEMORY_SERVICE_TOKEN:-}
|
||||
depends_on:
|
||||
memory-db:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
import hashlib
|
||||
import base64
|
||||
@@ -7,8 +8,8 @@ import base64
|
||||
import asyncpg
|
||||
import httpx
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
logger = logging.getLogger("memory-service")
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -23,12 +24,23 @@ EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
|
||||
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
|
||||
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
|
||||
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
|
||||
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
|
||||
|
||||
app = FastAPI(title="Memory Service")
|
||||
pool: asyncpg.Pool | None = None
|
||||
owner_pool: asyncpg.Pool | None = None
|
||||
|
||||
|
||||
async def verify_token(authorization: str | None = Header(None)):
|
||||
"""Bearer token auth — skipped if MEMORY_SERVICE_TOKEN not configured (dev mode)."""
|
||||
if not MEMORY_SERVICE_TOKEN:
|
||||
return
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(401, "Missing or invalid Authorization header")
|
||||
if not secrets.compare_digest(authorization[7:], MEMORY_SERVICE_TOKEN):
|
||||
raise HTTPException(403, "Invalid token")
|
||||
|
||||
|
||||
def _derive_user_key(user_id: str) -> bytes:
|
||||
"""Derive a per-user Fernet key from master key + user_id via HMAC-SHA256."""
|
||||
if not ENCRYPTION_KEY:
|
||||
@@ -64,12 +76,26 @@ class StoreRequest(BaseModel):
|
||||
fact: str
|
||||
source_room: str = ""
|
||||
|
||||
@field_validator('user_id')
|
||||
@classmethod
|
||||
def user_id_not_empty(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("user_id is required")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
user_id: str
|
||||
query: str
|
||||
top_k: int = 10
|
||||
|
||||
@field_validator('user_id')
|
||||
@classmethod
|
||||
def user_id_not_empty(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("user_id is required")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ChunkStoreRequest(BaseModel):
|
||||
user_id: str
|
||||
@@ -79,13 +105,27 @@ class ChunkStoreRequest(BaseModel):
|
||||
source_event_id: str = ""
|
||||
original_ts: float = 0.0
|
||||
|
||||
@field_validator('user_id')
|
||||
@classmethod
|
||||
def user_id_not_empty(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("user_id is required")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ChunkQueryRequest(BaseModel):
|
||||
user_id: str = ""
|
||||
user_id: str # REQUIRED — no default
|
||||
room_id: str = ""
|
||||
query: str
|
||||
top_k: int = 5
|
||||
|
||||
@field_validator('user_id')
|
||||
@classmethod
|
||||
def user_id_not_empty(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("user_id is required")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ChunkBulkStoreRequest(BaseModel):
|
||||
chunks: list[ChunkStoreRequest]
|
||||
@@ -166,6 +206,9 @@ async def _init_db():
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
|
||||
""")
|
||||
finally:
|
||||
await owner_conn.close()
|
||||
# Create restricted pool for all request handlers (RLS applies)
|
||||
@@ -205,7 +248,7 @@ async def health():
|
||||
|
||||
|
||||
@app.post("/memories/store")
|
||||
async def store_memory(req: StoreRequest):
|
||||
async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
|
||||
"""Embed fact, deduplicate by cosine similarity, insert encrypted."""
|
||||
if not req.fact.strip():
|
||||
raise HTTPException(400, "Empty fact")
|
||||
@@ -243,7 +286,7 @@ async def store_memory(req: StoreRequest):
|
||||
|
||||
|
||||
@app.post("/memories/query")
|
||||
async def query_memories(req: QueryRequest):
|
||||
async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
|
||||
"""Embed query, return top-K similar facts for user."""
|
||||
embedding = await _embed(req.query)
|
||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
@@ -276,7 +319,7 @@ async def query_memories(req: QueryRequest):
|
||||
|
||||
|
||||
@app.delete("/memories/{user_id}")
|
||||
async def delete_user_memories(user_id: str):
|
||||
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||
"""GDPR delete — remove all memories for a user."""
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
@@ -287,7 +330,7 @@ async def delete_user_memories(user_id: str):
|
||||
|
||||
|
||||
@app.get("/memories/{user_id}")
|
||||
async def list_user_memories(user_id: str):
|
||||
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||
"""List all memories for a user (for UI/debug)."""
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
@@ -314,7 +357,7 @@ async def list_user_memories(user_id: str):
|
||||
|
||||
|
||||
@app.post("/chunks/store")
|
||||
async def store_chunk(req: ChunkStoreRequest):
|
||||
async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
|
||||
"""Store a conversation chunk with its summary embedding, encrypted."""
|
||||
if not req.summary.strip():
|
||||
raise HTTPException(400, "Empty summary")
|
||||
@@ -342,34 +385,26 @@ async def store_chunk(req: ChunkStoreRequest):
|
||||
|
||||
|
||||
@app.post("/chunks/query")
|
||||
async def query_chunks(req: ChunkQueryRequest):
|
||||
async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
|
||||
"""Semantic search over conversation chunks. Filter by user_id and/or room_id."""
|
||||
embedding = await _embed(req.query)
|
||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
|
||||
# Build WHERE clause dynamically based on provided filters
|
||||
conditions = []
|
||||
params: list = [vec_literal]
|
||||
idx = 2
|
||||
# Build WHERE clause — user_id is always required
|
||||
conditions = [f"user_id = $2"]
|
||||
params: list = [vec_literal, req.user_id]
|
||||
idx = 3
|
||||
|
||||
if req.user_id:
|
||||
conditions.append(f"user_id = ${idx}")
|
||||
params.append(req.user_id)
|
||||
idx += 1
|
||||
if req.room_id:
|
||||
conditions.append(f"room_id = ${idx}")
|
||||
params.append(req.room_id)
|
||||
idx += 1
|
||||
|
||||
where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
where = f"WHERE {' AND '.join(conditions)}"
|
||||
params.append(req.top_k)
|
||||
|
||||
# Determine user_id for RLS and decryption
|
||||
rls_user = req.user_id if req.user_id else ""
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
if rls_user:
|
||||
await _set_rls_user(conn, rls_user)
|
||||
await _set_rls_user(conn, req.user_id)
|
||||
|
||||
rows = await conn.fetch(
|
||||
f"""
|
||||
@@ -399,7 +434,7 @@ async def query_chunks(req: ChunkQueryRequest):
|
||||
|
||||
|
||||
@app.post("/chunks/bulk-store")
|
||||
async def bulk_store_chunks(req: ChunkBulkStoreRequest):
|
||||
async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
|
||||
"""Batch store conversation chunks. Embeds summaries in batches of 20."""
|
||||
if not req.chunks:
|
||||
return {"stored": 0}
|
||||
@@ -440,7 +475,7 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest):
|
||||
|
||||
|
||||
@app.get("/chunks/{user_id}/count")
|
||||
async def count_user_chunks(user_id: str):
|
||||
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
|
||||
"""Count conversation chunks for a user."""
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
|
||||
Reference in New Issue
Block a user