Files
matrix-ai-agent/memory-service/main.py
Christian Gick 36c7e36456 security: enforce per-user data isolation in memory service
- Make user_id required on all request models with field validators
- Always include user_id in WHERE clause for chunk queries (prevents cross-user data leak)
- Add bearer token auth on all endpoints except /health
- Add composite index on (user_id, room_id) for conversation_chunks
- Bot: guard query_chunks with sender check, pass room_id, send auth token
- Docker: pass MEMORY_SERVICE_TOKEN to both bot and memory-service

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 13:45:15 +02:00

486 lines
16 KiB
Python

import os
import logging
import secrets
import time
import hashlib
import base64
import asyncpg
import httpx
from cryptography.fernet import Fernet
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, field_validator
logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)
# Restricted DSN used by request handlers (the pool RLS is meant to apply to).
DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
# Owner DSN for DDL and admin queries; in dev it falls back to the same credentials as DB_DSN.
OWNER_DSN = os.environ.get("OWNER_DATABASE_URL", "postgresql://memory:{pw}@memory-db:5432/memories".format(
    pw=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
))
# LiteLLM proxy that serves the /embeddings endpoint used by _embed/_embed_batch.
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
# Must match the vector(...) column width created in _init_db.
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
# Cosine-similarity threshold above which a new fact counts as a duplicate.
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
# Master key for per-user Fernet key derivation; empty string disables encryption.
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
# Shared bearer token; empty string disables auth (dev mode) — see verify_token.
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
app = FastAPI(title="Memory Service")
# Connection pools, created on startup: `pool` for request handlers,
# `owner_pool` for admin/health queries.
pool: asyncpg.Pool | None = None
owner_pool: asyncpg.Pool | None = None
async def verify_token(authorization: str | None = Header(None)):
    """Validate the shared bearer token on incoming requests.

    No-op when MEMORY_SERVICE_TOKEN is not configured (dev mode). Raises
    401 for a missing/malformed header, 403 for a wrong token.
    """
    if not MEMORY_SERVICE_TOKEN:
        return
    prefix = "Bearer "
    if authorization is None or not authorization.startswith(prefix):
        raise HTTPException(401, "Missing or invalid Authorization header")
    supplied = authorization[len(prefix):]
    # Constant-time comparison to avoid leaking the token via timing.
    if not secrets.compare_digest(supplied, MEMORY_SERVICE_TOKEN):
        raise HTTPException(403, "Invalid token")
def _derive_user_key(user_id: str) -> bytes:
    """Derive the per-user Fernet key from the master key and *user_id*.

    Raises RuntimeError when MEMORY_ENCRYPTION_KEY is not configured.
    """
    if not ENCRYPTION_KEY:
        raise RuntimeError("MEMORY_ENCRYPTION_KEY not set")
    # Single PBKDF2-HMAC-SHA256 round with user_id as salt — effectively keyed
    # hashing; presumably the master key is high-entropy so no stretching is
    # needed. Do not change parameters: existing ciphertexts depend on them.
    raw = hashlib.pbkdf2_hmac(
        "sha256",
        ENCRYPTION_KEY.encode(),
        user_id.encode(),
        iterations=1,
    )
    # Fernet requires a urlsafe-base64-encoded 32-byte key.
    return base64.urlsafe_b64encode(raw)
def _encrypt(text: str, user_id: str) -> str:
    """Encrypt *text* under the user's derived key.

    Pass-through when encryption is disabled (no master key configured).
    Returns the Fernet token as a str.
    """
    if not ENCRYPTION_KEY:
        return text
    cipher = Fernet(_derive_user_key(user_id))
    token = cipher.encrypt(text.encode())
    return token.decode()
def _decrypt(ciphertext: str, user_id: str) -> str:
    """Decrypt *ciphertext* with the user's derived key.

    Returns the input unchanged when encryption is disabled, or when
    decryption fails (rows written before encryption was enabled are
    stored as plaintext).
    """
    if not ENCRYPTION_KEY:
        return ciphertext
    try:
        cipher = Fernet(_derive_user_key(user_id))
        return cipher.decrypt(ciphertext.encode()).decode()
    except Exception:
        # Plaintext fallback for not-yet-migrated rows
        return ciphertext
class StoreRequest(BaseModel):
    """Payload for POST /memories/store."""
    user_id: str
    fact: str
    source_room: str = ""

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        # Reject empty/whitespace ids so no row is ever written without an owner.
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("user_id is required")
        return cleaned
class QueryRequest(BaseModel):
    """Payload for POST /memories/query."""
    user_id: str
    query: str
    top_k: int = 10

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        # Reject empty/whitespace ids — queries must always be user-scoped.
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("user_id is required")
        return cleaned
class ChunkStoreRequest(BaseModel):
    """Payload for POST /chunks/store (also the element type for bulk store)."""
    user_id: str
    room_id: str
    chunk_text: str
    summary: str
    source_event_id: str = ""
    original_ts: float = 0.0

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        # Reject empty/whitespace ids so every stored chunk has an owner.
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("user_id is required")
        return cleaned
class ChunkQueryRequest(BaseModel):
    """Payload for POST /chunks/query."""
    user_id: str  # REQUIRED — no default
    room_id: str = ""
    query: str
    top_k: int = 5

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        # Reject empty/whitespace ids — chunk queries must be user-scoped.
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("user_id is required")
        return cleaned
class ChunkBulkStoreRequest(BaseModel):
    """Payload for POST /chunks/bulk-store."""
    # Each element is validated individually, so every chunk carries a
    # non-empty user_id and per-user isolation still holds.
    chunks: list[ChunkStoreRequest]
async def _embed(text: str) -> list[float]:
    """Return the embedding vector for *text* via the LiteLLM /embeddings API.

    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    payload = {"model": EMBED_MODEL, "input": text}
    auth_header = {"Authorization": f"Bearer {LITELLM_KEY}"}
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(
            f"{LITELLM_URL}/embeddings",
            json=payload,
            headers=auth_header,
        )
        response.raise_for_status()
        body = response.json()
        return body["data"][0]["embedding"]
async def _embed_batch(texts: list[str]) -> list[list[float]]:
    """Embed several texts in one LiteLLM call, preserving input order.

    The API may return items out of order, so results are re-sorted by the
    per-item "index" field before extracting the vectors.
    """
    payload = {"model": EMBED_MODEL, "input": texts}
    auth_header = {"Authorization": f"Bearer {LITELLM_KEY}"}
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            f"{LITELLM_URL}/embeddings",
            json=payload,
            headers=auth_header,
        )
        response.raise_for_status()
        items = sorted(response.json()["data"], key=lambda entry: entry["index"])
        return [entry["embedding"] for entry in items]
async def _set_rls_user(conn, user_id: str):
    """Set the RLS session variable for the current connection."""
    # is_local=false: the setting is session-scoped, so it persists on the
    # pooled connection after release — every handler must call this before
    # touching user data, and all handlers here do.
    # NOTE(review): no CREATE POLICY / ENABLE ROW LEVEL SECURITY statement is
    # visible in this file; confirm the policies exist (e.g. in a migration),
    # otherwise this setting has no enforcement effect on its own.
    await conn.execute("SELECT set_config('app.current_user_id', $1, false)", user_id)
async def _init_db():
    """Create pgvector extension and memories table if not exists."""
    # DDL runs on a one-off owner connection; request handlers then use the
    # restricted `pool` (DB_DSN) so row-level security can apply to them.
    global pool
    # Use owner connection for DDL (CREATE TABLE/INDEX), then create restricted pool
    owner_conn = await asyncpg.connect(OWNER_DSN)
    conn = owner_conn
    try:
        await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        # Long-term "fact" memories: one embedding per fact, owned by user_id.
        await conn.execute(f"""
            CREATE TABLE IF NOT EXISTS memories (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                fact TEXT NOT NULL,
                source_room TEXT DEFAULT '',
                created_at DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id)
        """)
        # Approximate-NN index for cosine distance; lists=10 suits small tables.
        await conn.execute(f"""
            CREATE INDEX IF NOT EXISTS idx_memories_embedding
            ON memories USING ivfflat (embedding vector_cosine_ops)
            WITH (lists = 10)
        """)
        # Conversation chunks table for RAG over chat history
        await conn.execute(f"""
            CREATE TABLE IF NOT EXISTS conversation_chunks (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                room_id TEXT NOT NULL,
                chunk_text TEXT NOT NULL,
                summary TEXT NOT NULL,
                source_event_id TEXT DEFAULT '',
                original_ts DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user_id ON conversation_chunks (user_id)
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
        """)
        # Composite index backs the common (user_id, room_id) filter in /chunks/query.
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
        """)
    finally:
        await owner_conn.close()
    # Create restricted pool for all request handlers (RLS applies)
    pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
    # Owner pool for admin queries (bypasses RLS) — 1 connection only
    global owner_pool
    owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
    logger.info("Database initialized (dims=%d, encryption=%s)", EMBED_DIMS, "ON" if ENCRYPTION_KEY else "OFF")
@app.on_event("startup")
async def startup():
    """Create schema and connection pools when the app starts."""
    # NOTE(review): on_event is deprecated in newer FastAPI — consider
    # migrating to a lifespan context manager.
    await _init_db()
@app.on_event("shutdown")
async def shutdown():
    """Close both connection pools on app shutdown (each may be None if startup failed)."""
    if pool:
        await pool.close()
    if owner_pool:
        await owner_pool.close()
@app.get("/health")
async def health():
    """Unauthenticated liveness probe; reports row counts when the DB is up."""
    if not owner_pool:
        return {"status": "no_db"}
    # Counts use the owner pool so they cover all users regardless of RLS.
    async with owner_pool.acquire() as conn:
        memories = await conn.fetchval("SELECT count(*) FROM memories")
        chunks = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
    return {
        "status": "ok",
        "total_memories": memories,
        "total_chunks": chunks,
        "encryption": "on" if ENCRYPTION_KEY else "off",
    }
@app.post("/memories/store")
async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
    """Embed fact, deduplicate by cosine similarity, insert encrypted.

    Returns {"stored": False, "reason": "duplicate"} when an existing memory
    for the same user exceeds DEDUP_THRESHOLD similarity; otherwise inserts
    and returns {"stored": True}. Raises 400 on an empty/whitespace fact.
    """
    fact = req.fact.strip()
    if not fact:
        raise HTTPException(400, "Empty fact")
    # Fix: embed the *stripped* text so the stored vector matches the stored
    # fact (previously the raw input was embedded but the stripped text saved).
    embedding = await _embed(fact)
    # pgvector accepts a "[v1,v2,...]" literal cast via ::vector.
    vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        # Check for duplicates (cosine similarity > threshold)
        dup = await conn.fetchval(
            """
            SELECT id FROM memories
            WHERE user_id = $1
            AND 1 - (embedding <=> $2::vector) > $3
            LIMIT 1
            """,
            req.user_id, vec_literal, DEDUP_THRESHOLD,
        )
        if dup:
            logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, dup)
            return {"stored": False, "reason": "duplicate"}
        encrypted_fact = _encrypt(fact, req.user_id)
        await conn.execute(
            """
            INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
            VALUES ($1, $2, $3, $4, $5::vector)
            """,
            req.user_id, encrypted_fact, req.source_room, time.time(), vec_literal,
        )
    logger.info("Stored memory for %s: %s", req.user_id, fact[:60])
    return {"stored": True}
@app.post("/memories/query")
async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
    """Embed the query and return the top-K most similar facts for the user."""
    query_vec = await _embed(req.query)
    vec_literal = "[" + ",".join(str(v) for v in query_vec) + "]"
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at,
            1 - (embedding <=> $1::vector) AS similarity
            FROM memories
            WHERE user_id = $2
            ORDER BY embedding <=> $1::vector
            LIMIT $3
            """,
            vec_literal, req.user_id, req.top_k,
        )
    # Decrypt each fact with the requesting user's key before returning.
    results = []
    for row in rows:
        results.append({
            "fact": _decrypt(row["fact"], req.user_id),
            "source_room": row["source_room"],
            "created_at": row["created_at"],
            "similarity": float(row["similarity"]),
        })
    return {"results": results}
@app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
    """GDPR delete — remove every stored memory belonging to *user_id*."""
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        status = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
    # asyncpg returns a command tag like "DELETE 3"; the last token is the count.
    deleted = int(status.rsplit(" ", 1)[-1])
    logger.info("Deleted %d memories for %s", deleted, user_id)
    return {"deleted": deleted}
@app.get("/memories/{user_id}")
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
    """Return every memory for *user_id*, newest first (for UI/debug)."""
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at
            FROM memories
            WHERE user_id = $1
            ORDER BY created_at DESC
            """,
            user_id,
        )
    items = [
        {
            "fact": _decrypt(row["fact"], user_id),
            "source_room": row["source_room"],
            "created_at": row["created_at"],
        }
        for row in rows
    ]
    return {"user_id": user_id, "count": len(items), "memories": items}
# --- Conversation Chunks ---
@app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
    """Store one conversation chunk, encrypted, indexed by its summary embedding."""
    if not req.summary.strip():
        raise HTTPException(400, "Empty summary")
    # The summary (not the full chunk text) is what gets embedded for retrieval.
    vec = await _embed(req.summary)
    vec_literal = "[" + ",".join(str(v) for v in vec) + "]"
    when = req.original_ts or time.time()
    chunk_ct = _encrypt(req.chunk_text, req.user_id)
    summary_ct = _encrypt(req.summary, req.user_id)
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        await conn.execute(
            """
            INSERT INTO conversation_chunks
            (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
            """,
            req.user_id, req.room_id, chunk_ct, summary_ct,
            req.source_event_id, when, vec_literal,
        )
    logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60])
    return {"stored": True}
@app.post("/chunks/query")
async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
    """Semantic search over conversation chunks, always scoped to req.user_id.

    When room_id is provided the search is additionally restricted to that room.
    """
    vec = await _embed(req.query)
    vec_literal = "[" + ",".join(str(v) for v in vec) + "]"
    # $1 is the query vector; $2 the mandatory user filter; the optional room
    # filter and the LIMIT take the next positional slots in order.
    params: list = [vec_literal, req.user_id]
    filters = ["user_id = $2"]
    if req.room_id:
        params.append(req.room_id)
        filters.append(f"room_id = ${len(params)}")
    where = f"WHERE {' AND '.join(filters)}"
    params.append(req.top_k)
    limit_slot = len(params)
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            f"""
            SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
            1 - (embedding <=> $1::vector) AS similarity
            FROM conversation_chunks
            {where}
            ORDER BY embedding <=> $1::vector
            LIMIT ${limit_slot}
            """,
            *params,
        )
    results = []
    for row in rows:
        owner = row["user_id"]
        results.append({
            "chunk_text": _decrypt(row["chunk_text"], owner),
            "summary": _decrypt(row["summary"], owner),
            "room_id": row["room_id"],
            "user_id": owner,
            "original_ts": row["original_ts"],
            "source_event_id": row["source_event_id"],
            "similarity": float(row["similarity"]),
        })
    return {"results": results}
@app.post("/chunks/bulk-store")
async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
    """Batch store conversation chunks, embedding summaries 20 at a time.

    A failed embedding call skips that batch (logged) and continues with the
    rest; the response reports how many chunks were actually inserted.
    """
    if not req.chunks:
        return {"stored": 0}
    batch_size = 20
    total = 0
    for start in range(0, len(req.chunks), batch_size):
        group = req.chunks[start:start + batch_size]
        try:
            vectors = await _embed_batch([c.summary.strip() for c in group])
        except Exception:
            logger.error("Batch embed failed for chunks %d-%d", start, start + len(group), exc_info=True)
            continue
        async with pool.acquire() as conn:
            for chunk, vector in zip(group, vectors):
                # Chunks in one request may belong to different users, so the
                # RLS variable is refreshed per chunk.
                await _set_rls_user(conn, chunk.user_id)
                literal = "[" + ",".join(str(v) for v in vector) + "]"
                when = chunk.original_ts or time.time()
                await conn.execute(
                    """
                    INSERT INTO conversation_chunks
                    (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
                    VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
                    """,
                    chunk.user_id, chunk.room_id,
                    _encrypt(chunk.chunk_text, chunk.user_id),
                    _encrypt(chunk.summary, chunk.user_id),
                    chunk.source_event_id, when, literal,
                )
                total += 1
    logger.info("Bulk stored %d chunks", total)
    return {"stored": total}
@app.get("/chunks/{user_id}/count")
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
    """Return how many conversation chunks are stored for *user_id*."""
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        total = await conn.fetchval(
            "SELECT count(*) FROM conversation_chunks WHERE user_id = $1", user_id,
        )
    return {"user_id": user_id, "count": total}