Files
matrix-ai-agent/memory-service/main.py
Christian Gick 4cd7a0262e feat: Replace JSON memory with pgvector semantic search (MAT-11)
Add memory-service (FastAPI + pgvector) for semantic memory storage.
Bot now queries relevant memories per conversation instead of dumping all 50.
Includes migration script for existing JSON files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 06:25:50 +02:00

192 lines
5.9 KiB
Python

import os
import logging
import time
import asyncpg
import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Module logger; basicConfig configures the root logger at INFO level.
logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)
# Runtime configuration, read once from the environment at import time.
# NOTE(review): the fallback DSN embeds credentials — presumably intended
# for a local/compose setup only; confirm it is overridden in deployment.
DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
# Base URL and bearer token used by _embed() for the /embeddings endpoint.
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
# Embedding model name, and the vector width used for the `embedding` column.
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
# Cosine-similarity cutoff above which a new fact is treated as a duplicate.
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
app = FastAPI(title="Memory Service")
# Shared asyncpg connection pool; assigned in _init_db(), closed on shutdown.
pool: asyncpg.Pool | None = None
class StoreRequest(BaseModel):
    """Request body for ``POST /memories/store``."""

    # Owner of the memory; used verbatim as the per-user filter key.
    user_id: str
    # Text to remember; it is embedded and persisted by the store endpoint.
    fact: str
    # Saved alongside the fact; defaults to an empty string when omitted.
    source_room: str = ""
class QueryRequest(BaseModel):
    """Request body for ``POST /memories/query``."""

    # Only memories stored under this user_id are searched.
    user_id: str
    # Free text that is embedded and compared against stored facts.
    query: str
    # Maximum number of results to return (passed to SQL LIMIT).
    top_k: int = 10
async def _embed(text: str) -> list[float]:
    """Return the embedding vector for *text* from the LiteLLM proxy.

    Posts ``{"model": EMBED_MODEL, "input": text}`` to the ``/embeddings``
    endpoint under ``LITELLM_URL`` and returns the first embedding in the
    response. A trailing slash on the configured base URL is tolerated so
    the request never targets ``//embeddings``.

    Raises:
        httpx.HTTPStatusError: If the endpoint answers with a 4xx/5xx status.
        httpx.HTTPError: On transport failures or when the 30 s timeout hits.
    """
    # Normalise the base so "http://host/v1" and "http://host/v1/" both work.
    base_url = LITELLM_URL.rstrip("/")
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            f"{base_url}/embeddings",
            json={"model": EMBED_MODEL, "input": text},
            headers={"Authorization": f"Bearer {LITELLM_KEY}"},
        )
        resp.raise_for_status()
        return resp.json()["data"][0]["embedding"]
async def _init_db():
    """Open the asyncpg pool and make sure the pgvector schema exists.

    Assigns the module-level ``pool`` and then runs, in order, idempotent
    DDL for the ``vector`` extension, the ``memories`` table, an index on
    ``user_id`` and an ivfflat cosine index on ``embedding``.
    """
    global pool

    # EMBED_DIMS is already an int (parsed at import time), so formatting
    # it into the column definition is safe.
    schema_statements = (
        "CREATE EXTENSION IF NOT EXISTS vector",
        f"""
            CREATE TABLE IF NOT EXISTS memories (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                fact TEXT NOT NULL,
                source_room TEXT DEFAULT '',
                created_at DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """,
        """
            CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id)
        """,
        """
            CREATE INDEX IF NOT EXISTS idx_memories_embedding
            ON memories USING ivfflat (embedding vector_cosine_ops)
            WITH (lists = 10)
        """,
    )

    pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
    async with pool.acquire() as db:
        # Order matters: the extension must exist before the vector column,
        # and the table before its indexes.
        for statement in schema_statements:
            await db.execute(statement)
    logger.info("Database initialized (dims=%d)", EMBED_DIMS)
@app.on_event("startup")
async def startup():
    """Application startup hook: delegate to _init_db() for pool and schema setup."""
    await _init_db()
@app.on_event("shutdown")
async def shutdown():
    """Application shutdown hook: close the connection pool if one was created."""
    if not pool:
        # Nothing was opened, so there is nothing to release.
        return
    await pool.close()
@app.get("/health")
async def health():
    """Report service status and, when a pool exists, the total row count.

    Returns ``{"status": "no_db"}`` if the pool has not been created,
    otherwise ``{"status": "ok", "total_memories": <count>}``.
    """
    if not pool:
        return {"status": "no_db"}
    async with pool.acquire() as db:
        total = await db.fetchval("SELECT count(*) FROM memories")
    return {"status": "ok", "total_memories": total}
@app.post("/memories/store")
async def store_memory(req: StoreRequest):
    """Embed a fact, skip it if a near-duplicate exists, otherwise insert it.

    The fact is whitespace-stripped once, and that same normalised text is
    both embedded and persisted, so the stored vector always corresponds to
    the stored text.

    Returns:
        ``{"stored": True}`` after an insert, or
        ``{"stored": False, "reason": "duplicate"}`` when an existing memory
        of the same user has cosine similarity above ``DEDUP_THRESHOLD``.

    Raises:
        HTTPException: 400 if the fact is empty or whitespace-only.
    """
    fact = req.fact.strip()
    if not fact:
        raise HTTPException(400, "Empty fact")

    embedding = await _embed(fact)
    # pgvector accepts the textual form "[v1,v2,...]" cast with ::vector.
    vec_literal = "[" + ",".join(map(str, embedding)) + "]"

    async with pool.acquire() as conn:
        # `<=>` is cosine distance, so 1 - distance is cosine similarity.
        # NOTE: the check and the insert are separate statements; two
        # concurrent requests carrying the same fact can both pass the check.
        dup = await conn.fetchval(
            """
            SELECT id FROM memories
            WHERE user_id = $1
              AND 1 - (embedding <=> $2::vector) > $3
            LIMIT 1
            """,
            req.user_id, vec_literal, DEDUP_THRESHOLD,
        )
        # Compare against None explicitly rather than relying on the
        # truthiness of the returned id.
        if dup is not None:
            logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, dup)
            return {"stored": False, "reason": "duplicate"}

        await conn.execute(
            """
            INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
            VALUES ($1, $2, $3, $4, $5::vector)
            """,
            req.user_id, fact, req.source_room, time.time(), vec_literal,
        )

    logger.info("Stored memory for %s: %s", req.user_id, fact[:60])
    return {"stored": True}
@app.post("/memories/query")
async def query_memories(req: QueryRequest):
    """Embed the query and return the user's top-K most similar facts.

    Returns:
        ``{"results": [...]}`` where each item has ``fact``, ``source_room``,
        ``created_at`` and ``similarity`` (1 - cosine distance), ordered from
        most to least similar and limited to ``req.top_k`` rows.

    Raises:
        HTTPException: 400 if ``top_k`` is negative or the query is empty or
            whitespace-only.
    """
    # A negative LIMIT is rejected by PostgreSQL and would otherwise surface
    # as an unhandled database error; fail early with a client error instead.
    if req.top_k < 0:
        raise HTTPException(400, "top_k must not be negative")
    # Mirror store_memory's empty-input check before spending an embedding call.
    if not req.query.strip():
        raise HTTPException(400, "Empty query")

    embedding = await _embed(req.query)
    # pgvector accepts the textual form "[v1,v2,...]" cast with ::vector.
    vec_literal = "[" + ",".join(map(str, embedding)) + "]"

    async with pool.acquire() as conn:
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at,
                   1 - (embedding <=> $1::vector) AS similarity
            FROM memories
            WHERE user_id = $2
            ORDER BY embedding <=> $1::vector
            LIMIT $3
            """,
            vec_literal, req.user_id, req.top_k,
        )

    results = [
        {
            "fact": r["fact"],
            "source_room": r["source_room"],
            "created_at": r["created_at"],
            "similarity": float(r["similarity"]),
        }
        for r in rows
    ]
    return {"results": results}
@app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str):
    """Erase every memory stored for *user_id* (GDPR deletion).

    Returns ``{"deleted": <number of rows removed>}``.
    """
    async with pool.acquire() as db:
        status = await db.execute("DELETE FROM memories WHERE user_id = $1", user_id)
    # The status string ends with the affected-row count; take the last
    # whitespace-separated token and parse it.
    removed = int(status.rsplit(None, 1)[-1])
    logger.info("Deleted %d memories for %s", removed, user_id)
    return {"deleted": removed}
@app.get("/memories/{user_id}")
async def list_user_memories(user_id: str):
    """Return all memories of *user_id*, newest first (for UI/debug).

    The response carries ``user_id``, ``count`` and a ``memories`` list whose
    items expose ``fact``, ``source_room`` and ``created_at``.
    """
    listing_sql = """
        SELECT fact, source_room, created_at
        FROM memories
        WHERE user_id = $1
        ORDER BY created_at DESC
    """
    async with pool.acquire() as db:
        records = await db.fetch(listing_sql, user_id)

    exposed_columns = ("fact", "source_room", "created_at")
    memories = [
        {column: record[column] for column in exposed_columns}
        for record in records
    ]
    return {
        "user_id": user_id,
        "count": len(records),
        "memories": memories,
    }