import base64
import hashlib
import logging
import os
import secrets
import time

import asyncpg
import httpx
from cryptography.fernet import Fernet
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, field_validator

logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)

# --- Configuration (environment-driven) ---
DB_DSN = os.environ.get(
    "DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories"
)
OWNER_DSN = os.environ.get(
    "OWNER_DATABASE_URL",
    "postgresql://memory:{pw}@memory-db:5432/memories".format(
        pw=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
    ),
)
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")

app = FastAPI(title="Memory Service")
pool: asyncpg.Pool | None = None        # restricted pool — RLS applies
owner_pool: asyncpg.Pool | None = None  # owner pool — bypasses RLS (admin only)


def _require_pool() -> asyncpg.Pool:
    """Return the restricted pool, or fail fast with 503 if startup hasn't run.

    Previously handlers dereferenced the Optional ``pool`` directly, which
    produced an AttributeError (opaque 500) when a request arrived before
    startup completed.
    """
    if pool is None:
        raise HTTPException(503, "Database not initialized")
    return pool


async def verify_token(authorization: str | None = Header(None)):
    """Bearer token auth — skipped if MEMORY_SERVICE_TOKEN not configured (dev mode).

    Raises:
        HTTPException: 401 if the header is missing/malformed, 403 on mismatch.
    """
    if not MEMORY_SERVICE_TOKEN:
        return
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(401, "Missing or invalid Authorization header")
    # compare_digest gives constant-time comparison (timing-attack resistant)
    if not secrets.compare_digest(authorization[7:], MEMORY_SERVICE_TOKEN):
        raise HTTPException(403, "Invalid token")


def _derive_user_key(user_id: str) -> bytes:
    """Derive a per-user Fernet key from master key + user_id via HMAC-SHA256.

    NOTE(review): ``iterations=1`` collapses PBKDF2 to a single HMAC pass.
    That is tolerable only if MEMORY_ENCRYPTION_KEY is a high-entropy secret
    (not a human password). Deliberately left unchanged: raising the iteration
    count would change every derived key and make existing rows undecryptable.

    Raises:
        RuntimeError: if MEMORY_ENCRYPTION_KEY is not set.
    """
    if not ENCRYPTION_KEY:
        raise RuntimeError("MEMORY_ENCRYPTION_KEY not set")
    derived = hashlib.pbkdf2_hmac(
        "sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1
    )
    # Fernet requires a urlsafe-base64-encoded 32-byte key
    return base64.urlsafe_b64encode(derived)


def _encrypt(text: str, user_id: str) -> str:
    """Encrypt text with per-user Fernet key. Returns base64 ciphertext.

    No-op (returns plaintext) when encryption is not configured.
    """
    if not ENCRYPTION_KEY:
        return text
    f = Fernet(_derive_user_key(user_id))
    return f.encrypt(text.encode()).decode()


def _decrypt(ciphertext: str, user_id: str) -> str:
    """Decrypt ciphertext with per-user Fernet key.

    Returns the input unchanged when encryption is off, or when decryption
    fails (plaintext fallback for rows stored before encryption was enabled).
    """
    if not ENCRYPTION_KEY:
        return ciphertext
    try:
        f = Fernet(_derive_user_key(user_id))
        return f.decrypt(ciphertext.encode()).decode()
    except Exception:
        # Plaintext fallback for not-yet-migrated rows
        return ciphertext


class _RequiresUserId(BaseModel):
    """Base for request models that require a non-blank ``user_id``.

    Replaces four identical copy-pasted validators; the validator is
    inherited by every subclass.
    """

    user_id: str  # REQUIRED — no default

    @field_validator("user_id")
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()


class StoreRequest(_RequiresUserId):
    fact: str
    source_room: str = ""


class QueryRequest(_RequiresUserId):
    query: str
    top_k: int = 10


class ChunkStoreRequest(_RequiresUserId):
    room_id: str
    chunk_text: str
    summary: str
    source_event_id: str = ""
    original_ts: float = 0.0


class ChunkQueryRequest(_RequiresUserId):
    room_id: str = ""
    query: str
    top_k: int = 5


class ChunkBulkStoreRequest(BaseModel):
    chunks: list[ChunkStoreRequest]


def _vec_literal(embedding: list[float]) -> str:
    """Format an embedding as a pgvector text literal, e.g. ``[0.1,0.2,...]``."""
    return "[" + ",".join(str(v) for v in embedding) + "]"


async def _embed(text: str) -> list[float]:
    """Get embedding vector from LiteLLM /embeddings endpoint.

    Raises:
        httpx.HTTPStatusError: on non-2xx response from the embeddings API.
    """
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings",
            json={"model": EMBED_MODEL, "input": text},
            headers={"Authorization": f"Bearer {LITELLM_KEY}"},
        )
        resp.raise_for_status()
        return resp.json()["data"][0]["embedding"]


async def _embed_batch(texts: list[str]) -> list[list[float]]:
    """Get embedding vectors for a batch of texts, in input order."""
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings",
            json={"model": EMBED_MODEL, "input": texts},
            headers={"Authorization": f"Bearer {LITELLM_KEY}"},
        )
        resp.raise_for_status()
        data = resp.json()["data"]
        # The API may return items out of order; sort by index to realign
        return [item["embedding"] for item in sorted(data, key=lambda x: x["index"])]


async def _set_rls_user(conn, user_id: str) -> None:
    """Set the RLS session variable for the current connection.

    is_local=false means the setting is session-scoped and survives on the
    pooled connection after release — so every handler MUST call this before
    touching user-scoped tables, which all handlers below do.
    """
    await conn.execute("SELECT set_config('app.current_user_id', $1, false)", user_id)


async def _init_db():
    """Create pgvector extension, tables and indexes, then open the pools.

    DDL runs on a one-off owner connection (needs owner privileges); request
    handlers then use the restricted pool so row-level security applies.
    """
    global pool, owner_pool
    owner_conn = await asyncpg.connect(OWNER_DSN)
    try:
        await owner_conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        await owner_conn.execute(f"""
            CREATE TABLE IF NOT EXISTS memories (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                fact TEXT NOT NULL,
                source_room TEXT DEFAULT '',
                created_at DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_user_id
            ON memories (user_id)
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_embedding
            ON memories USING ivfflat (embedding vector_cosine_ops)
            WITH (lists = 10)
        """)
        # Conversation chunks table for RAG over chat history
        await owner_conn.execute(f"""
            CREATE TABLE IF NOT EXISTS conversation_chunks (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                room_id TEXT NOT NULL,
                chunk_text TEXT NOT NULL,
                summary TEXT NOT NULL,
                source_event_id TEXT DEFAULT '',
                original_ts DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user_id
            ON conversation_chunks (user_id)
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_room_id
            ON conversation_chunks (room_id)
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user_room
            ON conversation_chunks (user_id, room_id)
        """)
    finally:
        await owner_conn.close()

    # Restricted pool for all request handlers (RLS applies)
    pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
    # Owner pool for admin queries (bypasses RLS) — kept small
    owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
    logger.info(
        "Database initialized (dims=%d, encryption=%s)",
        EMBED_DIMS,
        "ON" if ENCRYPTION_KEY else "OFF",
    )


@app.on_event("startup")
async def startup():
    # NOTE: on_event is deprecated in recent FastAPI in favor of lifespan
    # handlers; kept for compatibility with the file's existing style.
    await _init_db()


@app.on_event("shutdown")
async def shutdown():
    if pool:
        await pool.close()
    if owner_pool:
        await owner_pool.close()


@app.get("/health")
async def health():
    """Liveness probe with global row counts (owner pool — bypasses RLS)."""
    if owner_pool:
        async with owner_pool.acquire() as conn:
            mem_count = await conn.fetchval("SELECT count(*) FROM memories")
            chunk_count = await conn.fetchval(
                "SELECT count(*) FROM conversation_chunks"
            )
        return {
            "status": "ok",
            "total_memories": mem_count,
            "total_chunks": chunk_count,
            "encryption": "on" if ENCRYPTION_KEY else "off",
        }
    return {"status": "no_db"}


@app.post("/memories/store")
async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
    """Embed fact, deduplicate by cosine similarity, insert encrypted."""
    if not req.fact.strip():
        raise HTTPException(400, "Empty fact")
    embedding = await _embed(req.fact)
    vec_literal = _vec_literal(embedding)
    async with _require_pool().acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        # Check for duplicates (cosine similarity > threshold).
        # NOTE(review): check-then-insert is not atomic; two concurrent
        # requests can both pass the check. Best-effort dedup only.
        dup = await conn.fetchval(
            """
            SELECT id FROM memories
            WHERE user_id = $1 AND 1 - (embedding <=> $2::vector) > $3
            LIMIT 1
            """,
            req.user_id,
            vec_literal,
            DEDUP_THRESHOLD,
        )
        if dup:
            logger.info(
                "Duplicate memory for %s (similar to id=%d), skipping",
                req.user_id,
                dup,
            )
            return {"stored": False, "reason": "duplicate"}
        encrypted_fact = _encrypt(req.fact.strip(), req.user_id)
        await conn.execute(
            """
            INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
            VALUES ($1, $2, $3, $4, $5::vector)
            """,
            req.user_id,
            encrypted_fact,
            req.source_room,
            time.time(),
            vec_literal,
        )
    logger.info("Stored memory for %s: %s", req.user_id, req.fact[:60])
    return {"stored": True}


@app.post("/memories/query")
async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
    """Embed query, return top-K similar facts for user."""
    embedding = await _embed(req.query)
    vec_literal = _vec_literal(embedding)
    async with _require_pool().acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at,
                   1 - (embedding <=> $1::vector) AS similarity
            FROM memories
            WHERE user_id = $2
            ORDER BY embedding <=> $1::vector
            LIMIT $3
            """,
            vec_literal,
            req.user_id,
            req.top_k,
        )
    results = [
        {
            "fact": _decrypt(r["fact"], req.user_id),
            "source_room": r["source_room"],
            "created_at": r["created_at"],
            "similarity": float(r["similarity"]),
        }
        for r in rows
    ]
    return {"results": results}


@app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
    """GDPR delete — remove all memories for a user."""
    async with _require_pool().acquire() as conn:
        await _set_rls_user(conn, user_id)
        result = await conn.execute(
            "DELETE FROM memories WHERE user_id = $1", user_id
        )
    # asyncpg returns a status tag like "DELETE 5"; last token is the count
    count = int(result.split()[-1])
    logger.info("Deleted %d memories for %s", count, user_id)
    return {"deleted": count}


@app.get("/memories/{user_id}")
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
    """List all memories for a user (for UI/debug)."""
    async with _require_pool().acquire() as conn:
        await _set_rls_user(conn, user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at FROM memories
            WHERE user_id = $1 ORDER BY created_at DESC
            """,
            user_id,
        )
    return {
        "user_id": user_id,
        "count": len(rows),
        "memories": [
            {
                "fact": _decrypt(r["fact"], user_id),
                "source_room": r["source_room"],
                "created_at": r["created_at"],
            }
            for r in rows
        ],
    }


# --- Conversation Chunks ---

@app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
    """Store a conversation chunk with its summary embedding, encrypted."""
    if not req.summary.strip():
        raise HTTPException(400, "Empty summary")
    embedding = await _embed(req.summary)
    vec_literal = _vec_literal(embedding)
    ts = req.original_ts or time.time()
    encrypted_text = _encrypt(req.chunk_text, req.user_id)
    encrypted_summary = _encrypt(req.summary, req.user_id)
    async with _require_pool().acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        await conn.execute(
            """
            INSERT INTO conversation_chunks
            (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
            """,
            req.user_id,
            req.room_id,
            encrypted_text,
            encrypted_summary,
            req.source_event_id,
            ts,
            vec_literal,
        )
    logger.info(
        "Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60]
    )
    return {"stored": True}


@app.post("/chunks/query")
async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
    """Semantic search over conversation chunks. Filter by user_id and/or room_id."""
    embedding = await _embed(req.query)
    vec_literal = _vec_literal(embedding)
    # Build WHERE clause — user_id is always required. Only static column
    # names go into the SQL text; all values are bound parameters.
    conditions = ["user_id = $2"]
    params: list = [vec_literal, req.user_id]
    idx = 3
    if req.room_id:
        conditions.append(f"room_id = ${idx}")
        params.append(req.room_id)
        idx += 1
    where = f"WHERE {' AND '.join(conditions)}"
    params.append(req.top_k)
    async with _require_pool().acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            f"""
            SELECT chunk_text, summary, room_id, user_id, original_ts,
                   source_event_id,
                   1 - (embedding <=> $1::vector) AS similarity
            FROM conversation_chunks
            {where}
            ORDER BY embedding <=> $1::vector
            LIMIT ${idx}
            """,
            *params,
        )
    results = [
        {
            "chunk_text": _decrypt(r["chunk_text"], r["user_id"]),
            "summary": _decrypt(r["summary"], r["user_id"]),
            "room_id": r["room_id"],
            "user_id": r["user_id"],
            "original_ts": r["original_ts"],
            "source_event_id": r["source_event_id"],
            "similarity": float(r["similarity"]),
        }
        for r in rows
    ]
    return {"results": results}


@app.post("/chunks/bulk-store")
async def bulk_store_chunks(
    req: ChunkBulkStoreRequest, _: None = Depends(verify_token)
):
    """Batch store conversation chunks. Embeds summaries in batches of 20.

    A failed embed batch is logged and skipped; remaining batches proceed.
    """
    if not req.chunks:
        return {"stored": 0}
    stored = 0
    batch_size = 20
    for i in range(0, len(req.chunks), batch_size):
        batch = req.chunks[i:i + batch_size]
        summaries = [c.summary.strip() for c in batch]
        try:
            embeddings = await _embed_batch(summaries)
        except Exception:
            logger.exception(
                "Batch embed failed for chunks %d-%d", i, i + len(batch)
            )
            continue
        async with _require_pool().acquire() as conn:
            for chunk, embedding in zip(batch, embeddings):
                # RLS var must track the row's owner — chunks in one request
                # may belong to different users
                await _set_rls_user(conn, chunk.user_id)
                vec_literal = _vec_literal(embedding)
                ts = chunk.original_ts or time.time()
                encrypted_text = _encrypt(chunk.chunk_text, chunk.user_id)
                encrypted_summary = _encrypt(chunk.summary, chunk.user_id)
                await conn.execute(
                    """
                    INSERT INTO conversation_chunks
                    (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
                    VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
                    """,
                    chunk.user_id,
                    chunk.room_id,
                    encrypted_text,
                    encrypted_summary,
                    chunk.source_event_id,
                    ts,
                    vec_literal,
                )
                stored += 1
    logger.info("Bulk stored %d chunks", stored)
    return {"stored": stored}


@app.get("/chunks/{user_id}/count")
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
    """Count conversation chunks for a user."""
    async with _require_pool().acquire() as conn:
        await _set_rls_user(conn, user_id)
        count = await conn.fetchval(
            "SELECT count(*) FROM conversation_chunks WHERE user_id = $1",
            user_id,
        )
    return {"user_id": user_id, "count": count}