"""Memory service: stores per-user facts as pgvector embeddings and answers similarity queries."""

import logging
import os
import time

import asyncpg
import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)

DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
# Trailing slash stripped so f"{LITELLM_URL}/embeddings" never produces "//embeddings".
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "").rstrip("/")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
# IVFFlat list count used when the embedding index is first created. CREATE INDEX
# IF NOT EXISTS leaves an existing index untouched, so changing this later does not
# rebuild it.
IVFFLAT_LISTS = int(os.environ.get("IVFFLAT_LISTS", "10"))

app = FastAPI(title="Memory Service")
# Created by startup() via _init_db(); closed and reset by shutdown().
pool: asyncpg.Pool | None = None


class StoreRequest(BaseModel):
    """Body of POST /memories/store."""

    user_id: str
    fact: str
    source_room: str = ""


class QueryRequest(BaseModel):
    """Body of POST /memories/query."""

    user_id: str
    query: str
    # A negative value would reach SQL LIMIT and fail inside Postgres (500);
    # validating here turns it into a 422. 0 stays allowed (empty result).
    top_k: int = Field(10, ge=0)


def _vec_literal(embedding: list[float]) -> str:
    """Format *embedding* as a pgvector text literal such as ``[0.1,0.2]``."""
    return "[" + ",".join(str(v) for v in embedding) + "]"


def _get_pool() -> asyncpg.Pool:
    """Return the connection pool.

    Raises:
        HTTPException(503): the pool has not been created (startup not run or failed).
    """
    if pool is None:
        raise HTTPException(503, "Database not initialized")
    return pool


async def _embed(text: str) -> list[float]:
    """Get the embedding vector for *text* from LiteLLM's ``/embeddings`` endpoint.

    Raises:
        HTTPException(503): ``LITELLM_BASE_URL`` is not configured.
        HTTPException(502): the request failed, the response was malformed, or the
            vector length differs from ``EMBED_DIMS``. A wrong length would
            otherwise only fail later, at INSERT time.
    """
    if not LITELLM_URL:
        raise HTTPException(503, "Embedding backend not configured (LITELLM_BASE_URL is empty)")
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{LITELLM_URL}/embeddings",
                json={"model": EMBED_MODEL, "input": text},
                headers={"Authorization": f"Bearer {LITELLM_KEY}"},
            )
            resp.raise_for_status()
            embedding = resp.json()["data"][0]["embedding"]
    except httpx.HTTPError as exc:
        logger.warning("Embedding request failed: %s", exc)
        raise HTTPException(502, "Embedding request failed") from exc
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        # ValueError covers a non-JSON body (json.JSONDecodeError).
        raise HTTPException(502, "Malformed embedding response") from exc
    if not isinstance(embedding, list) or len(embedding) != EMBED_DIMS:
        size = len(embedding) if isinstance(embedding, list) else "?"
        raise HTTPException(502, f"Embedding has {size} dims, expected {EMBED_DIMS}")
    return embedding


async def _init_db() -> None:
    """Create the pool, the pgvector extension, the memories table and its indexes."""
    global pool
    pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
    async with pool.acquire() as conn:
        await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        # EMBED_DIMS and IVFFLAT_LISTS are ints parsed from config, so
        # interpolating them into DDL is safe (DDL cannot take bind parameters).
        await conn.execute(f"""
            CREATE TABLE IF NOT EXISTS memories (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                fact TEXT NOT NULL,
                source_room TEXT DEFAULT '',
                created_at DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id)
        """)
        await conn.execute(f"""
            CREATE INDEX IF NOT EXISTS idx_memories_embedding
            ON memories USING ivfflat (embedding vector_cosine_ops)
            WITH (lists = {IVFFLAT_LISTS})
        """)
    logger.info("Database initialized (dims=%d)", EMBED_DIMS)


@app.on_event("startup")
async def startup() -> None:
    await _init_db()


@app.on_event("shutdown")
async def shutdown() -> None:
    global pool
    if pool is not None:
        await pool.close()
        # Reset so /health reports "no_db" instead of using a closed pool.
        pool = None


@app.get("/health")
async def health():
    """Report status and total stored memories; ``no_db`` if the pool is not set up."""
    if pool is None:
        return {"status": "no_db"}
    async with pool.acquire() as conn:
        count = await conn.fetchval("SELECT count(*) FROM memories")
    return {"status": "ok", "total_memories": count}


@app.post("/memories/store")
async def store_memory(req: StoreRequest):
    """Embed fact, deduplicate by cosine similarity, insert.

    Returns ``{"stored": True}``, or ``{"stored": False, "reason": "duplicate"}`` when
    an existing memory of the same user has similarity above ``DEDUP_THRESHOLD``.
    """
    fact = req.fact.strip()
    if not fact:
        raise HTTPException(400, "Empty fact")
    db = _get_pool()
    # Embed exactly the text that is stored, so each row's vector matches its fact.
    vec_literal = _vec_literal(await _embed(fact))

    async with db.acquire() as conn:
        async with conn.transaction():
            # Serialize stores per user for the rest of this transaction. Without
            # the lock, two concurrent requests could both pass the duplicate
            # check below and insert the same fact twice.
            await conn.execute("SELECT pg_advisory_xact_lock(hashtext($1))", req.user_id)

            # Check for duplicates (cosine similarity > threshold)
            dup = await conn.fetchval(
                """
                SELECT id FROM memories
                WHERE user_id = $1
                  AND 1 - (embedding <=> $2::vector) > $3
                LIMIT 1
                """,
                req.user_id, vec_literal, DEDUP_THRESHOLD,
            )
            if dup is not None:
                logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, dup)
                return {"stored": False, "reason": "duplicate"}

            await conn.execute(
                """
                INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
                VALUES ($1, $2, $3, $4, $5::vector)
                """,
                req.user_id, fact, req.source_room, time.time(), vec_literal,
            )

    # Log only the length: fact text in logs could not be purged by the GDPR delete endpoint.
    logger.info("Stored memory for %s (%d chars)", req.user_id, len(fact))
    return {"stored": True}


@app.post("/memories/query")
async def query_memories(req: QueryRequest):
    """Embed query, return top-K similar facts for user, most similar first."""
    query = req.query.strip()
    if not query:
        raise HTTPException(400, "Empty query")
    db = _get_pool()
    vec_literal = _vec_literal(await _embed(query))

    async with db.acquire() as conn:
        async with conn.transaction():
            # With an IVFFlat index, pgvector applies the user_id filter *after*
            # the approximate index scan. With the default probes=1, only about
            # 1/lists of the table is visited, so a user may get fewer than
            # top_k results, or none. Probing every list makes the search exact.
            # is_local=true scopes the setting to this transaction.
            await conn.execute(
                "SELECT set_config('ivfflat.probes', $1, true)", str(IVFFLAT_LISTS)
            )
            rows = await conn.fetch(
                """
                SELECT fact, source_room, created_at,
                       1 - (embedding <=> $1::vector) AS similarity
                FROM memories
                WHERE user_id = $2
                ORDER BY embedding <=> $1::vector
                LIMIT $3
                """,
                vec_literal, req.user_id, req.top_k,
            )

    results = [
        {
            "fact": r["fact"],
            "source_room": r["source_room"],
            "created_at": r["created_at"],
            "similarity": float(r["similarity"]),
        }
        for r in rows
    ]
    return {"results": results}


@app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str):
    """GDPR delete — remove all memories for a user."""
    async with _get_pool().acquire() as conn:
        # asyncpg returns the command status tag, e.g. "DELETE 5".
        result = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
    count = int(result.split()[-1])
    logger.info("Deleted %d memories for %s", count, user_id)
    return {"deleted": count}


@app.get("/memories/{user_id}")
async def list_user_memories(user_id: str):
    """List all memories for a user (for UI/debug), newest first."""
    async with _get_pool().acquire() as conn:
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at
            FROM memories
            WHERE user_id = $1
            ORDER BY created_at DESC
            """,
            user_id,
        )
    return {
        "user_id": user_id,
        "count": len(rows),
        "memories": [
            {"fact": r["fact"], "source_room": r["source_room"], "created_at": r["created_at"]}
            for r in rows
        ],
    }