Files
matrix-ai-agent/memory-service/main.py
Christian Gick 108144696b feat(MAT-107): memory encryption & user isolation
- Per-user Fernet encryption for fact/chunk_text/summary fields
- Postgres RLS with memory_app restricted role
- SSL for memory-db connections
- Data migration script (migrate_encrypt.py)
- DB migration (migrate_rls.sql)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 15:56:14 +00:00

451 lines
15 KiB
Python

import os
import logging
import time
import hashlib
import base64
import asyncpg
import httpx
from cryptography.fernet import Fernet
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# --- Module configuration & shared state -------------------------------------
logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)

# Restricted-role DSN used by request handlers (Postgres RLS applies).
DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
# Owner-role DSN for DDL and admin queries (bypasses RLS).
# NOTE(review): default password "memory" is a weak fallback — confirm it is
# always overridden via MEMORY_DB_OWNER_PASSWORD in deployment.
OWNER_DSN = os.environ.get("OWNER_DATABASE_URL", "postgresql://memory:{pw}@memory-db:5432/memories".format(
    pw=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
))
# LiteLLM proxy endpoint + credentials used for all embedding calls.
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
# Must match the vector(...) column width created in _init_db.
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
# Cosine-similarity threshold above which a new fact counts as a duplicate.
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
# Master key for per-user Fernet key derivation; empty string disables encryption.
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")

app = FastAPI(title="Memory Service")
# Connection pools, created in _init_db() during startup.
pool: asyncpg.Pool | None = None
owner_pool: asyncpg.Pool | None = None
def _derive_user_key(user_id: str) -> bytes:
    """Derive a per-user Fernet key from master key + user_id via HMAC-SHA256.

    Uses PBKDF2-HMAC-SHA256 with a single iteration, which amounts to one
    HMAC pass; acceptable only if ENCRYPTION_KEY is a high-entropy master
    key rather than a human-chosen password — TODO confirm in deployment.

    WARNING: the derivation (algorithm, salt = user_id, iteration count)
    must stay byte-for-byte stable; changing it makes all previously
    encrypted rows undecryptable.

    Raises:
        RuntimeError: if MEMORY_ENCRYPTION_KEY is not configured.
    """
    if not ENCRYPTION_KEY:
        raise RuntimeError("MEMORY_ENCRYPTION_KEY not set")
    derived = hashlib.pbkdf2_hmac(
        "sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1
    )
    # Fernet requires a urlsafe-base64-encoded 32-byte key.
    return base64.urlsafe_b64encode(derived)
def _encrypt(text: str, user_id: str) -> str:
    """Encrypt text with per-user Fernet key. Returns base64 ciphertext."""
    if not ENCRYPTION_KEY:
        # Encryption disabled — pass the value through unchanged.
        return text
    cipher = Fernet(_derive_user_key(user_id))
    token = cipher.encrypt(text.encode())
    return token.decode()
def _decrypt(ciphertext: str, user_id: str) -> str:
    """Decrypt ciphertext with per-user Fernet key."""
    if not ENCRYPTION_KEY:
        # Encryption disabled — the stored value is plaintext already.
        return ciphertext
    try:
        cipher = Fernet(_derive_user_key(user_id))
        return cipher.decrypt(ciphertext.encode()).decode()
    except Exception:
        # Rows written before the encryption migration are stored in
        # plaintext; return them verbatim rather than failing the request.
        return ciphertext
class StoreRequest(BaseModel):
    """Request body for POST /memories/store."""

    user_id: str
    fact: str  # plaintext fact; encrypted at rest when encryption is enabled
    source_room: str = ""  # optional provenance (room the fact came from)
class QueryRequest(BaseModel):
    """Request body for POST /memories/query."""

    user_id: str
    query: str
    top_k: int = 10  # maximum number of results to return
class ChunkStoreRequest(BaseModel):
    """Request body for POST /chunks/store (one conversation chunk)."""

    user_id: str
    room_id: str
    chunk_text: str  # raw conversation text; encrypted at rest
    summary: str  # summary that gets embedded; encrypted at rest
    source_event_id: str = ""
    original_ts: float = 0.0  # epoch seconds; 0.0 means "stamp with now"
class ChunkQueryRequest(BaseModel):
    """Request body for POST /chunks/query; empty filter fields mean "no filter"."""

    user_id: str = ""
    room_id: str = ""
    query: str
    top_k: int = 5
class ChunkBulkStoreRequest(BaseModel):
    """Request body for POST /chunks/bulk-store."""

    chunks: list[ChunkStoreRequest]
async def _embed(text: str) -> list[float]:
    """Get embedding vector from LiteLLM /embeddings endpoint."""
    payload = {"model": EMBED_MODEL, "input": text}
    auth = {"Authorization": f"Bearer {LITELLM_KEY}"}
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(f"{LITELLM_URL}/embeddings", json=payload, headers=auth)
        response.raise_for_status()
        body = response.json()
        return body["data"][0]["embedding"]
async def _embed_batch(texts: list[str]) -> list[list[float]]:
    """Get embedding vectors for a batch of texts."""
    payload = {"model": EMBED_MODEL, "input": texts}
    auth = {"Authorization": f"Bearer {LITELLM_KEY}"}
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(f"{LITELLM_URL}/embeddings", json=payload, headers=auth)
        response.raise_for_status()
        items = response.json()["data"]
    # The API may return items out of input order; restore order by index.
    items.sort(key=lambda item: item["index"])
    return [item["embedding"] for item in items]
async def _set_rls_user(conn, user_id: str):
"""Set the RLS session variable for the current connection."""
await conn.execute("SELECT set_config('app.current_user_id', $1, false)", user_id)
async def _init_db():
    """Create pgvector extension, tables, and indexes; open connection pools.

    DDL runs on a one-off owner connection (the restricted role has no DDL
    rights). Request handlers then use a restricted pool so Postgres RLS
    applies, while a small owner pool remains for admin queries.
    """
    global pool, owner_pool
    # Owner connection is used for DDL only and closed before pools open.
    conn = await asyncpg.connect(OWNER_DSN)
    try:
        await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        # Long-term user facts, one row per deduplicated fact.
        await conn.execute(f"""
            CREATE TABLE IF NOT EXISTS memories (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                fact TEXT NOT NULL,
                source_room TEXT DEFAULT '',
                created_at DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id)"
        )
        # ANN index for cosine search; lists=10 is tuned for small datasets.
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_embedding
            ON memories USING ivfflat (embedding vector_cosine_ops)
            WITH (lists = 10)
        """)
        # Conversation chunks table for RAG over chat history.
        await conn.execute(f"""
            CREATE TABLE IF NOT EXISTS conversation_chunks (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                room_id TEXT NOT NULL,
                chunk_text TEXT NOT NULL,
                summary TEXT NOT NULL,
                source_event_id TEXT DEFAULT '',
                original_ts DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_chunks_user_id ON conversation_chunks (user_id)"
        )
        await conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)"
        )
    finally:
        await conn.close()
    # Restricted pool for all request handlers (RLS applies).
    pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
    # Owner pool for admin queries (bypasses RLS) — kept tiny on purpose.
    owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
    logger.info("Database initialized (dims=%d, encryption=%s)", EMBED_DIMS, "ON" if ENCRYPTION_KEY else "OFF")
@app.on_event("startup")
async def startup():
    """Initialize database schema and connection pools on app startup."""
    await _init_db()
@app.on_event("shutdown")
async def shutdown():
    """Close both connection pools (if they were created) on app shutdown."""
    for candidate in (pool, owner_pool):
        if candidate:
            await candidate.close()
@app.get("/health")
async def health():
    """Health check: report global row counts via the owner pool.

    Uses the owner pool deliberately — the owner role bypasses RLS, so the
    counts cover all users.
    """
    if owner_pool:
        async with owner_pool.acquire() as conn:
            mem_count = await conn.fetchval("SELECT count(*) FROM memories")
            chunk_count = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
        return {
            "status": "ok",
            "total_memories": mem_count,
            "total_chunks": chunk_count,
            "encryption": "on" if ENCRYPTION_KEY else "off",
        }
    # Pools not initialized yet (startup not finished or DB unavailable).
    return {"status": "no_db"}
@app.post("/memories/store")
async def store_memory(req: StoreRequest):
    """Embed fact, deduplicate by cosine similarity, insert encrypted."""
    fact = req.fact
    if not fact.strip():
        raise HTTPException(400, "Empty fact")
    vector = await _embed(fact)
    vec_literal = "[" + ",".join(map(str, vector)) + "]"
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        # Skip the insert when an existing fact is nearly identical
        # (cosine similarity above DEDUP_THRESHOLD).
        duplicate_id = await conn.fetchval(
            """
            SELECT id FROM memories
            WHERE user_id = $1
            AND 1 - (embedding <=> $2::vector) > $3
            LIMIT 1
            """,
            req.user_id, vec_literal, DEDUP_THRESHOLD,
        )
        if duplicate_id:
            logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, duplicate_id)
            return {"stored": False, "reason": "duplicate"}
        await conn.execute(
            """
            INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
            VALUES ($1, $2, $3, $4, $5::vector)
            """,
            req.user_id, _encrypt(fact.strip(), req.user_id),
            req.source_room, time.time(), vec_literal,
        )
    logger.info("Stored memory for %s: %s", req.user_id, fact[:60])
    return {"stored": True}
@app.post("/memories/query")
async def query_memories(req: QueryRequest):
    """Embed query, return top-K similar facts for user."""
    query_vec = await _embed(req.query)
    vec_literal = "[" + ",".join(map(str, query_vec)) + "]"
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at,
            1 - (embedding <=> $1::vector) AS similarity
            FROM memories
            WHERE user_id = $2
            ORDER BY embedding <=> $1::vector
            LIMIT $3
            """,
            vec_literal, req.user_id, req.top_k,
        )
    hits = []
    for row in rows:
        hits.append({
            "fact": _decrypt(row["fact"], req.user_id),
            "source_room": row["source_room"],
            "created_at": row["created_at"],
            "similarity": float(row["similarity"]),
        })
    return {"results": hits}
@app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str):
    """GDPR delete — remove all memories for a user."""
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        status = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
    # asyncpg returns a status tag like "DELETE 3"; the count is the tail.
    deleted = int(status.rsplit(" ", 1)[-1])
    logger.info("Deleted %d memories for %s", deleted, user_id)
    return {"deleted": deleted}
@app.get("/memories/{user_id}")
async def list_user_memories(user_id: str):
    """List all memories for a user (for UI/debug)."""
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at
            FROM memories
            WHERE user_id = $1
            ORDER BY created_at DESC
            """,
            user_id,
        )
    memories = [
        {
            "fact": _decrypt(row["fact"], user_id),
            "source_room": row["source_room"],
            "created_at": row["created_at"],
        }
        for row in rows
    ]
    return {"user_id": user_id, "count": len(memories), "memories": memories}
# --- Conversation Chunks ---
@app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest):
    """Store a conversation chunk with its summary embedding, encrypted."""
    if not req.summary.strip():
        raise HTTPException(400, "Empty summary")
    vector = await _embed(req.summary)
    vec_literal = "[" + ",".join(str(v) for v in vector) + "]"
    # original_ts of 0.0 (the field default) means "stamp with current time".
    timestamp = req.original_ts if req.original_ts else time.time()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        await conn.execute(
            """
            INSERT INTO conversation_chunks
            (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
            """,
            req.user_id, req.room_id,
            _encrypt(req.chunk_text, req.user_id),
            _encrypt(req.summary, req.user_id),
            req.source_event_id, timestamp, vec_literal,
        )
    logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60])
    return {"stored": True}
@app.post("/chunks/query")
async def query_chunks(req: ChunkQueryRequest):
    """Semantic search over conversation chunks. Filter by user_id and/or room_id."""
    embedding = await _embed(req.query)
    vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
    # Build WHERE clause dynamically based on provided filters.
    # Only numbered placeholders ($n) are interpolated into the SQL text;
    # all values go through bind parameters, so this is injection-safe.
    conditions = []
    params: list = [vec_literal]
    idx = 2
    if req.user_id:
        conditions.append(f"user_id = ${idx}")
        params.append(req.user_id)
        idx += 1
    if req.room_id:
        conditions.append(f"room_id = ${idx}")
        params.append(req.room_id)
        idx += 1
    where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
    # top_k is always the last parameter; idx now points at its placeholder.
    params.append(req.top_k)
    # Determine user_id for RLS and decryption
    rls_user = req.user_id if req.user_id else ""
    async with pool.acquire() as conn:
        # NOTE(review): when no user_id is supplied the RLS session variable
        # is never set on this restricted connection — what the RLS policy
        # then exposes is decided in migrate_rls.sql; confirm cross-user
        # queries are intended to work through this pool.
        if rls_user:
            await _set_rls_user(conn, rls_user)
        rows = await conn.fetch(
            f"""
            SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
            1 - (embedding <=> $1::vector) AS similarity
            FROM conversation_chunks
            {where}
            ORDER BY embedding <=> $1::vector
            LIMIT ${idx}
            """,
            *params,
        )
    results = [
        {
            # Decrypt with each row's own user key — results may span users.
            "chunk_text": _decrypt(r["chunk_text"], r["user_id"]),
            "summary": _decrypt(r["summary"], r["user_id"]),
            "room_id": r["room_id"],
            "user_id": r["user_id"],
            "original_ts": r["original_ts"],
            "source_event_id": r["source_event_id"],
            "similarity": float(r["similarity"]),
        }
        for r in rows
    ]
    return {"results": results}
@app.post("/chunks/bulk-store")
async def bulk_store_chunks(req: ChunkBulkStoreRequest):
    """Batch store conversation chunks. Embeds summaries in batches of 20.

    Chunks with an empty/whitespace-only summary are skipped up front,
    mirroring the 400 validation in /chunks/store — previously they were
    sent to the embedding API, where a blank input could fail the whole
    batch of 20 and drop valid chunks with it.

    Returns:
        {"stored": <number of chunks actually inserted>}
    """
    # Drop chunks that have nothing to embed.
    valid = [c for c in req.chunks if c.summary.strip()]
    skipped = len(req.chunks) - len(valid)
    if skipped:
        logger.warning("Skipping %d chunks with empty summaries", skipped)
    if not valid:
        return {"stored": 0}
    stored = 0
    batch_size = 20
    for i in range(0, len(valid), batch_size):
        batch = valid[i:i + batch_size]
        summaries = [c.summary.strip() for c in batch]
        try:
            embeddings = await _embed_batch(summaries)
        except Exception:
            # Best-effort: log and continue so one failed embedding batch
            # does not abort the remaining batches.
            logger.error("Batch embed failed for chunks %d-%d", i, i + len(batch), exc_info=True)
            continue
        async with pool.acquire() as conn:
            for chunk, embedding in zip(batch, embeddings):
                # RLS variable must match each chunk's owner before insert.
                await _set_rls_user(conn, chunk.user_id)
                vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
                ts = chunk.original_ts or time.time()
                await conn.execute(
                    """
                    INSERT INTO conversation_chunks
                    (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
                    VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
                    """,
                    chunk.user_id, chunk.room_id,
                    _encrypt(chunk.chunk_text, chunk.user_id),
                    _encrypt(chunk.summary, chunk.user_id),
                    chunk.source_event_id, ts, vec_literal,
                )
                stored += 1
    logger.info("Bulk stored %d chunks", stored)
    return {"stored": stored}
@app.get("/chunks/{user_id}/count")
async def count_user_chunks(user_id: str):
    """Count conversation chunks for a user."""
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        total = await conn.fetchval(
            "SELECT count(*) FROM conversation_chunks WHERE user_id = $1", user_id,
        )
    return {"user_id": user_id, "count": total}