"""Memory service: per-user facts and conversation chunks with vector search.

A FastAPI application that stores text in PostgreSQL (pgvector), embeds it via
a LiteLLM-compatible ``/embeddings`` endpoint, and optionally encrypts the
stored text with a per-user Fernet key derived from ``MEMORY_ENCRYPTION_KEY``.
"""

import base64
import hashlib
import logging
import os
import time

import asyncpg
import httpx
from cryptography.fernet import Fernet, InvalidToken
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)

DB_DSN = os.environ.get(
    "DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories"
)
# NOTE(review): the owner password is interpolated into the URL verbatim, so
# characters such as '@' or '/' would need to be percent-encoded by whoever
# sets MEMORY_DB_OWNER_PASSWORD — confirm how it is provisioned.
OWNER_DSN = os.environ.get(
    "OWNER_DATABASE_URL",
    "postgresql://memory:{pw}@memory-db:5432/memories".format(
        pw=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
    ),
)
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")

app = FastAPI(title="Memory Service")

# Both pools are created in _init_db() at startup and closed at shutdown.
pool: asyncpg.Pool | None = None
owner_pool: asyncpg.Pool | None = None


def _derive_user_key(user_id: str) -> bytes:
    """Derive a per-user Fernet key from the master key and ``user_id``.

    Uses PBKDF2-HMAC-SHA256 with a single iteration, the master key as the
    password and ``user_id`` as the salt. The digest is urlsafe-base64
    encoded before being returned.

    The derivation parameters must stay exactly as they are: any change
    produces different keys, and rows encrypted with the old keys would no
    longer decrypt.

    Raises:
        RuntimeError: if ``MEMORY_ENCRYPTION_KEY`` is not set.
    """
    if not ENCRYPTION_KEY:
        raise RuntimeError("MEMORY_ENCRYPTION_KEY not set")
    derived = hashlib.pbkdf2_hmac(
        "sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1
    )
    return base64.urlsafe_b64encode(derived)


def _encrypt(text: str, user_id: str) -> str:
    """Encrypt ``text`` with the per-user Fernet key.

    Returns the Fernet token as a string, or ``text`` unchanged when no
    encryption key is configured.
    """
    if not ENCRYPTION_KEY:
        return text
    fernet = Fernet(_derive_user_key(user_id))
    return fernet.encrypt(text.encode()).decode()


def _decrypt(ciphertext: str, user_id: str) -> str:
    """Decrypt ``ciphertext`` with the per-user Fernet key.

    Returns the input unchanged when no encryption key is configured, or
    when the value is not a valid token for this user's key.
    """
    if not ENCRYPTION_KEY:
        return ciphertext
    try:
        fernet = Fernet(_derive_user_key(user_id))
        return fernet.decrypt(ciphertext.encode()).decode()
    except InvalidToken:
        # Plaintext fallback for not-yet-migrated rows. Only InvalidToken is
        # caught so that unrelated failures are not silently masked.
        return ciphertext


def _vec_literal(embedding: list[float]) -> str:
    """Render an embedding as the ``[v1,v2,...]`` text form cast to ``::vector``."""
    return "[" + ",".join(str(v) for v in embedding) + "]"


class StoreRequest(BaseModel):
    """Body of ``POST /memories/store``."""

    user_id: str
    fact: str
    source_room: str = ""


class QueryRequest(BaseModel):
    """Body of ``POST /memories/query``."""

    user_id: str
    query: str
    top_k: int = 10


class ChunkStoreRequest(BaseModel):
    """Body of ``POST /chunks/store`` and element of a bulk-store request."""

    user_id: str
    room_id: str
    chunk_text: str
    summary: str
    source_event_id: str = ""
    original_ts: float = 0.0  # falsy value means "use the current time"


class ChunkQueryRequest(BaseModel):
    """Body of ``POST /chunks/query``; empty ``user_id``/``room_id`` disables that filter."""

    user_id: str = ""
    room_id: str = ""
    query: str
    top_k: int = 5


class ChunkBulkStoreRequest(BaseModel):
    """Body of ``POST /chunks/bulk-store``."""

    chunks: list[ChunkStoreRequest]


async def _embed(text: str) -> list[float]:
    """Get an embedding vector for ``text`` from the LiteLLM /embeddings endpoint.

    Raises:
        httpx.HTTPStatusError: if the endpoint answers with an error status.
    """
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings",
            json={"model": EMBED_MODEL, "input": text},
            headers={"Authorization": f"Bearer {LITELLM_KEY}"},
        )
        resp.raise_for_status()
        return resp.json()["data"][0]["embedding"]


async def _embed_batch(texts: list[str]) -> list[list[float]]:
    """Get embedding vectors for a batch of texts, ordered by the response's ``index``.

    Raises:
        httpx.HTTPStatusError: if the endpoint answers with an error status.
    """
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings",
            json={"model": EMBED_MODEL, "input": texts},
            headers={"Authorization": f"Bearer {LITELLM_KEY}"},
        )
        resp.raise_for_status()
        data = resp.json()["data"]
        return [item["embedding"] for item in sorted(data, key=lambda x: x["index"])]


async def _set_rls_user(conn, user_id: str) -> None:
    """Set the ``app.current_user_id`` session variable on ``conn``.

    ``set_config`` is called with ``is_local = false``, so the value applies
    to the session rather than to a single transaction.
    NOTE(review): the RLS policies that read this variable are not created in
    this module — presumably they are provisioned elsewhere; confirm.
    """
    await conn.execute(
        "SELECT set_config('app.current_user_id', $1, false)", user_id
    )


async def _init_db() -> None:
    """Create the pgvector extension, tables and indexes, then open both pools."""
    global pool, owner_pool

    # Use an owner connection for DDL (CREATE TABLE/INDEX).
    owner_conn = await asyncpg.connect(OWNER_DSN)
    try:
        await owner_conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        # EMBED_DIMS is an int parsed from the environment, so interpolating
        # it into the DDL is safe.
        await owner_conn.execute(f"""
            CREATE TABLE IF NOT EXISTS memories (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                fact TEXT NOT NULL,
                source_room TEXT DEFAULT '',
                created_at DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id)
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_embedding
            ON memories USING ivfflat (embedding vector_cosine_ops)
            WITH (lists = 10)
        """)
        # Conversation chunks table for RAG over chat history.
        await owner_conn.execute(f"""
            CREATE TABLE IF NOT EXISTS conversation_chunks (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                room_id TEXT NOT NULL,
                chunk_text TEXT NOT NULL,
                summary TEXT NOT NULL,
                source_event_id TEXT DEFAULT '',
                original_ts DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user_id ON conversation_chunks (user_id)
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
        """)
    finally:
        await owner_conn.close()

    # Restricted pool for all request handlers (RLS is intended to apply).
    pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
    # Small owner pool (1-2 connections) for admin queries such as /health.
    owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
    logger.info(
        "Database initialized (dims=%d, encryption=%s)",
        EMBED_DIMS,
        "ON" if ENCRYPTION_KEY else "OFF",
    )


@app.on_event("startup")
async def startup():
    """Initialize the schema and connection pools when the app starts."""
    await _init_db()


@app.on_event("shutdown")
async def shutdown():
    """Close whichever connection pools were opened."""
    if pool:
        await pool.close()
    if owner_pool:
        await owner_pool.close()


@app.get("/health")
async def health():
    """Report row counts and encryption status, or ``no_db`` before the pools exist."""
    if not owner_pool:
        return {"status": "no_db"}
    async with owner_pool.acquire() as conn:
        mem_count = await conn.fetchval("SELECT count(*) FROM memories")
        chunk_count = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
    return {
        "status": "ok",
        "total_memories": mem_count,
        "total_chunks": chunk_count,
        "encryption": "on" if ENCRYPTION_KEY else "off",
    }


@app.post("/memories/store")
async def store_memory(req: StoreRequest):
    """Embed a fact, skip it if a near-duplicate exists, otherwise insert it encrypted.

    Raises:
        HTTPException: 400 if the fact is empty or whitespace only.
    """
    # Strip once so the text that is validated, embedded and stored is the same.
    fact = req.fact.strip()
    if not fact:
        raise HTTPException(400, "Empty fact")

    embedding = await _embed(fact)
    vec_literal = _vec_literal(embedding)

    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)

        # Check for duplicates (cosine similarity > threshold).
        dup = await conn.fetchval(
            """
            SELECT id FROM memories
            WHERE user_id = $1 AND 1 - (embedding <=> $2::vector) > $3
            LIMIT 1
            """,
            req.user_id, vec_literal, DEDUP_THRESHOLD,
        )
        if dup is not None:
            logger.info(
                "Duplicate memory for %s (similar to id=%d), skipping",
                req.user_id, dup,
            )
            return {"stored": False, "reason": "duplicate"}

        encrypted_fact = _encrypt(fact, req.user_id)
        await conn.execute(
            """
            INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
            VALUES ($1, $2, $3, $4, $5::vector)
            """,
            req.user_id, encrypted_fact, req.source_room, time.time(), vec_literal,
        )

    # Log only the length: the fact is encrypted at rest, so its plaintext
    # must not end up in the logs.
    logger.info("Stored memory for %s (%d chars)", req.user_id, len(fact))
    return {"stored": True}


@app.post("/memories/query")
async def query_memories(req: QueryRequest):
    """Embed the query and return the ``top_k`` most similar facts for the user."""
    embedding = await _embed(req.query)
    vec_literal = _vec_literal(embedding)

    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at,
                   1 - (embedding <=> $1::vector) AS similarity
            FROM memories
            WHERE user_id = $2
            ORDER BY embedding <=> $1::vector
            LIMIT $3
            """,
            vec_literal, req.user_id, req.top_k,
        )

    results = [
        {
            "fact": _decrypt(r["fact"], req.user_id),
            "source_room": r["source_room"],
            "created_at": r["created_at"],
            "similarity": float(r["similarity"]),
        }
        for r in rows
    ]
    return {"results": results}


@app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str):
    """GDPR delete — remove all rows in ``memories`` for a user.

    NOTE(review): rows in ``conversation_chunks`` are not touched here —
    confirm whether they are deleted by another path.
    """
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        result = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)

    # The command status string ends with the affected row count ("DELETE <n>").
    count = int(result.split()[-1])
    logger.info("Deleted %d memories for %s", count, user_id)
    return {"deleted": count}


@app.get("/memories/{user_id}")
async def list_user_memories(user_id: str):
    """List all memories for a user, newest first (for UI/debug)."""
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at
            FROM memories
            WHERE user_id = $1
            ORDER BY created_at DESC
            """,
            user_id,
        )

    return {
        "user_id": user_id,
        "count": len(rows),
        "memories": [
            {
                "fact": _decrypt(r["fact"], user_id),
                "source_room": r["source_room"],
                "created_at": r["created_at"],
            }
            for r in rows
        ],
    }


# --- Conversation Chunks ---


@app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest):
    """Store a conversation chunk, encrypted, indexed by its summary's embedding.

    Raises:
        HTTPException: 400 if the summary is empty or whitespace only.
    """
    if not req.summary.strip():
        raise HTTPException(400, "Empty summary")

    embedding = await _embed(req.summary)
    vec_literal = _vec_literal(embedding)
    ts = req.original_ts or time.time()
    encrypted_text = _encrypt(req.chunk_text, req.user_id)
    encrypted_summary = _encrypt(req.summary, req.user_id)

    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        await conn.execute(
            """
            INSERT INTO conversation_chunks
                (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
            """,
            req.user_id, req.room_id, encrypted_text, encrypted_summary,
            req.source_event_id, ts, vec_literal,
        )

    # Log only the length: the summary is encrypted at rest, so its plaintext
    # must not end up in the logs.
    logger.info(
        "Stored chunk for %s in %s (%d chars summary)",
        req.user_id, req.room_id, len(req.summary),
    )
    return {"stored": True}


@app.post("/chunks/query")
async def query_chunks(req: ChunkQueryRequest):
    """Semantic search over conversation chunks. Filter by user_id and/or room_id."""
    embedding = await _embed(req.query)
    vec_literal = _vec_literal(embedding)

    # Build the WHERE clause from the provided filters. Only fixed fragments
    # and positional placeholders are interpolated; values go through params.
    conditions = []
    params: list = [vec_literal]
    idx = 2
    if req.user_id:
        conditions.append(f"user_id = ${idx}")
        params.append(req.user_id)
        idx += 1
    if req.room_id:
        conditions.append(f"room_id = ${idx}")
        params.append(req.room_id)
        idx += 1
    where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
    params.append(req.top_k)

    async with pool.acquire() as conn:
        # NOTE(review): with an empty user_id no RLS user is set for this
        # query — confirm the database policies handle that case as intended.
        if req.user_id:
            await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            f"""
            SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
                   1 - (embedding <=> $1::vector) AS similarity
            FROM conversation_chunks
            {where}
            ORDER BY embedding <=> $1::vector
            LIMIT ${idx}
            """,
            *params,
        )

    # Decrypt with each row's own user_id, since the request may not name one.
    results = [
        {
            "chunk_text": _decrypt(r["chunk_text"], r["user_id"]),
            "summary": _decrypt(r["summary"], r["user_id"]),
            "room_id": r["room_id"],
            "user_id": r["user_id"],
            "original_ts": r["original_ts"],
            "source_event_id": r["source_event_id"],
            "similarity": float(r["similarity"]),
        }
        for r in rows
    ]
    return {"results": results}


@app.post("/chunks/bulk-store")
async def bulk_store_chunks(req: ChunkBulkStoreRequest):
    """Batch store conversation chunks. Embeds summaries in batches of 20.

    Chunks with an empty summary are skipped (``/chunks/store`` rejects them
    too). A batch whose embedding call fails is logged and skipped; the
    remaining batches are still processed. Returns the number stored.
    """
    if not req.chunks:
        return {"stored": 0}

    # Drop empty summaries up front so one bad chunk cannot fail a whole batch.
    valid = [c for c in req.chunks if c.summary.strip()]
    skipped = len(req.chunks) - len(valid)
    if skipped:
        logger.warning("Skipping %d chunks with empty summary", skipped)

    stored = 0
    batch_size = 20
    for start in range(0, len(valid), batch_size):
        batch = valid[start:start + batch_size]
        try:
            embeddings = await _embed_batch([c.summary.strip() for c in batch])
        except Exception:
            # Deliberate best-effort: log and move on to the next batch.
            logger.error(
                "Batch embed failed for chunks %d-%d",
                start, start + len(batch), exc_info=True,
            )
            continue

        # zip() would silently truncate, pairing chunks with wrong or missing vectors.
        if len(embeddings) != len(batch):
            logger.error(
                "Embedding count mismatch for chunks %d-%d: expected %d, got %d",
                start, start + len(batch), len(batch), len(embeddings),
            )
            continue

        async with pool.acquire() as conn:
            for chunk, embedding in zip(batch, embeddings):
                # Chunks in one request may belong to different users.
                await _set_rls_user(conn, chunk.user_id)
                ts = chunk.original_ts or time.time()
                await conn.execute(
                    """
                    INSERT INTO conversation_chunks
                        (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
                    VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
                    """,
                    chunk.user_id, chunk.room_id,
                    _encrypt(chunk.chunk_text, chunk.user_id),
                    _encrypt(chunk.summary, chunk.user_id),
                    chunk.source_event_id, ts, _vec_literal(embedding),
                )
                stored += 1

    logger.info("Bulk stored %d chunks", stored)
    return {"stored": stored}


@app.get("/chunks/{user_id}/count")
async def count_user_chunks(user_id: str):
    """Count conversation chunks for a user."""
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        count = await conn.fetchval(
            "SELECT count(*) FROM conversation_chunks WHERE user_id = $1",
            user_id,
        )
    return {"user_id": user_id, "count": count}