From 4cd7a0262ecae025aa6d2a79f435294f7f5171dc Mon Sep 17 00:00:00 2001 From: Christian Gick Date: Fri, 20 Feb 2026 06:25:50 +0200 Subject: [PATCH] feat: Replace JSON memory with pgvector semantic search (MAT-11) Add memory-service (FastAPI + pgvector) for semantic memory storage. Bot now queries relevant memories per conversation instead of dumping all 50. Includes migration script for existing JSON files. Co-Authored-By: Claude Opus 4.6 --- bot.py | 153 +++++++++++++------------ docker-compose.yml | 37 +++++++ memory-service/Dockerfile | 6 + memory-service/main.py | 191 ++++++++++++++++++++++++++++++++ memory-service/migrate_json.py | 108 ++++++++++++++++++ memory-service/requirements.txt | 5 + 6 files changed, 432 insertions(+), 68 deletions(-) create mode 100644 memory-service/Dockerfile create mode 100644 memory-service/main.py create mode 100644 memory-service/migrate_json.py create mode 100644 memory-service/requirements.txt diff --git a/bot.py b/bot.py index 6df1b07..8bc307b 100644 --- a/bot.py +++ b/bot.py @@ -8,8 +8,6 @@ import re import time import uuid -import hashlib - import fitz # pymupdf import httpx from openai import AsyncOpenAI @@ -60,8 +58,7 @@ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet") WILDFILES_BASE_URL = os.environ.get("WILDFILES_BASE_URL", "") WILDFILES_ORG = os.environ.get("WILDFILES_ORG", "") USER_KEYS_FILE = os.environ.get("USER_KEYS_FILE", "/data/user_keys.json") -MEMORIES_DIR = os.environ.get("MEMORIES_DIR", "/data/memories") -MAX_MEMORIES_PER_USER = 50 +MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090") SYSTEM_PROMPT = """You are a helpful AI assistant in a Matrix chat room. Keep answers concise but thorough. Use markdown formatting when helpful. @@ -190,6 +187,65 @@ class DocumentRAG: return "\n".join(parts) +class MemoryClient: + """Async HTTP client for the memory-service.""" + + def __init__(self, base_url: str): + self.base_url = base_url.rstrip("/") + self.enabled = bool(base_url) + + async def store(self, user_id: str, fact: str, source_room: str = ""): + if not self.enabled: + return + try: + async with httpx.AsyncClient(timeout=10.0) as client: + await client.post( + f"{self.base_url}/memories/store", + json={"user_id": user_id, "fact": fact, "source_room": source_room}, + ) + except Exception: + logger.warning("Memory store failed", exc_info=True) + + async def query(self, user_id: str, query: str, top_k: int = 10) -> list[dict]: + if not self.enabled: + return [] + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{self.base_url}/memories/query", + json={"user_id": user_id, "query": query, "top_k": top_k}, + ) + resp.raise_for_status() + return resp.json().get("results", []) + except Exception: + logger.warning("Memory query failed", exc_info=True) + return [] + + async def delete_user(self, user_id: str) -> int: + if not self.enabled: + return 0 + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.delete(f"{self.base_url}/memories/{user_id}") + resp.raise_for_status() + return resp.json().get("deleted", 0) + except Exception: + logger.warning("Memory delete failed", exc_info=True) + return 0 + + async def list_all(self, user_id: str) -> list[dict]: + if not self.enabled: + return [] + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(f"{self.base_url}/memories/{user_id}") + resp.raise_for_status() + return resp.json().get("memories", []) + except Exception: + logger.warning("Memory 
list failed", exc_info=True) + return [] + + class Bot: def __init__(self): config = AsyncClientConfig( @@ -208,6 +264,7 @@ class Bot: self.dispatched_rooms = set() self.active_calls = set() # rooms where we've sent call member event self.rag = DocumentRAG(WILDFILES_BASE_URL, WILDFILES_ORG) + self.memory = MemoryClient(MEMORY_SERVICE_URL) self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None self.user_keys: dict[str, str] = self._load_user_keys() # matrix_user_id -> api_key self.room_models: dict[str, str] = {} # room_id -> model name @@ -414,46 +471,22 @@ class Bot: # --- User memory helpers --- - def _memory_path(self, user_id: str) -> str: - """Get the file path for a user's memory store.""" - uid_hash = hashlib.sha256(user_id.encode()).hexdigest()[:16] - return os.path.join(MEMORIES_DIR, f"{uid_hash}.json") - - def _load_memories(self, user_id: str) -> list[dict]: - """Load memories for a user. Returns list of {fact, created, source_room}.""" - path = self._memory_path(user_id) - try: - with open(path) as f: - return json.load(f) - except (FileNotFoundError, json.JSONDecodeError): - return [] - - def _save_memories(self, user_id: str, memories: list[dict]): - """Save memories for a user, capping at MAX_MEMORIES_PER_USER.""" - os.makedirs(MEMORIES_DIR, exist_ok=True) - # Keep only the most recent memories - memories = memories[-MAX_MEMORIES_PER_USER:] - path = self._memory_path(user_id) - with open(path, "w") as f: - json.dump(memories, f, indent=2) - - def _format_memories(self, memories: list[dict]) -> str: - """Format memories as a system prompt section.""" + @staticmethod + def _format_memories(memories: list[dict]) -> str: + """Format memory query results as a system prompt section.""" if not memories: return "" facts = [m["fact"] for m in memories] return "You have these memories about this user:\n" + "\n".join(f"- {f}" for f in facts) - async def _extract_memories(self, user_message: str, ai_reply: str, - existing: list[dict], model: str, - sender: str, room_id: str) -> list[dict]: - """Use LLM to extract memorable facts from the conversation, deduplicate with existing.""" + async def _extract_and_store_memories(self, user_message: str, ai_reply: str, + existing_facts: list[str], model: str, + sender: str, room_id: str): + """Use LLM to extract memorable facts, then store each via memory-service.""" if not self.llm: - return existing + return - existing_facts = [m["fact"] for m in existing] existing_text = "\n".join(f"- {f}" for f in existing_facts) if existing_facts else "(none)" - logger.info("Memory extraction: user_msg=%s... 
(%d existing facts)", user_message[:80], len(existing_facts)) try: @@ -481,7 +514,6 @@ class Bot: raw = resp.choices[0].message.content.strip() logger.info("Memory extraction raw response: %s", raw[:200]) - # Robust JSON extraction: strip markdown fences, find array if raw.startswith("```"): raw = re.sub(r"^```\w*\n?", "", raw) raw = re.sub(r"\n?```$", "", raw) @@ -491,22 +523,17 @@ class Bot: new_facts = json.loads(raw) if not isinstance(new_facts, list): logger.warning("Memory extraction returned non-list: %s", type(new_facts)) - return existing + return logger.info("Memory extraction found %d new facts", len(new_facts)) - - now = time.time() for fact in new_facts: if isinstance(fact, str) and fact.strip(): - existing.append({"fact": fact.strip(), "created": now, "source_room": room_id}) + await self.memory.store(sender, fact.strip(), room_id) - return existing except json.JSONDecodeError: logger.warning("Memory extraction JSON parse failed, raw: %s", raw[:200]) - return existing except Exception: logger.warning("Memory extraction failed", exc_info=True) - return existing async def _detect_language(self, text: str) -> str: """Detect the language of a text using a fast LLM call.""" @@ -544,9 +571,9 @@ class Bot: logger.debug("Translation failed", exc_info=True) return f"[Translation failed] {text}" - def _get_preferred_language(self, user_id: str) -> str: + async def _get_preferred_language(self, user_id: str) -> str: """Get user's preferred language from memories (last match = most recent).""" - memories = self._load_memories(user_id) + memories = await self.memory.query(user_id, "preferred language", top_k=5) known_langs = [ "English", "German", "French", "Spanish", "Italian", "Portuguese", "Dutch", "Russian", "Chinese", "Japanese", "Korean", "Arabic", @@ -620,7 +647,7 @@ class Bot: if is_dm and sender in self._pending_translate: pending = self._pending_translate.pop(sender) choice = body.strip().lower() - preferred_lang = self._get_preferred_language(sender) + preferred_lang = await self._get_preferred_language(sender) if choice in ("1", "1️⃣") or choice.startswith("translate"): await self.client.room_typing(room.room_id, typing_state=True) @@ -653,7 +680,7 @@ class Bot: # --- DM translation workflow: detect foreign language --- if is_dm and not body.startswith("!ai") and not image_data: - preferred_lang = self._get_preferred_language(sender) + preferred_lang = await self._get_preferred_language(sender) detected_lang = await self._detect_language(body) logger.info("Translation check: detected=%s, preferred=%s, len=%d", detected_lang, preferred_lang, len(body)) if ( @@ -952,22 +979,17 @@ class Bot: elif cmd == "forget": sender = event.sender if event else None if sender: - path = self._memory_path(sender) - try: - os.remove(path) - except FileNotFoundError: - pass - # Clear any in-memory caches for this user + deleted = await self.memory.delete_user(sender) self._pending_translate.pop(sender, None) self._pending_reply.pop(sender, None) - await self._send_text(room.room_id, "All my memories about you have been deleted.") + await self._send_text(room.room_id, f"All my memories about you have been deleted ({deleted} facts removed).") else: await self._send_text(room.room_id, "Could not identify user.") elif cmd == "memories": sender = event.sender if event else None if sender: - memories = self._load_memories(sender) + memories = await self.memory.list_all(sender) if memories: text = f"**I remember {len(memories)} things about you:**\n" text += "\n".join(f"- {m['fact']}" for m in memories) 
@@ -1148,8 +1170,8 @@ class Bot: else: logger.info("RAG found 0 docs for: %s (original: %s)", search_query[:50], user_message[:50]) - # Load user memories - memories = self._load_memories(sender) if sender else [] + # Query relevant memories via semantic search + memories = await self.memory.query(sender, user_message, top_k=10) if sender else [] memory_context = self._format_memories(memories) # Build conversation context @@ -1191,21 +1213,16 @@ class Bot: else: await self._send_text(room.room_id, reply) - # Extract and save new memories (after reply sent, with timeout) + # Extract and store new memories (after reply sent, with timeout) if sender and reply: + existing_facts = [m["fact"] for m in memories] try: - updated = await asyncio.wait_for( - self._extract_memories( - user_message, reply, memories, model, sender, room.room_id + await asyncio.wait_for( + self._extract_and_store_memories( + user_message, reply, existing_facts, model, sender, room.room_id ), timeout=15.0, ) - if len(updated) > len(memories): - self._save_memories(sender, updated) - logger.info("Saved %d new memories for %s (total: %d)", - len(updated) - len(memories), sender, len(updated)) - else: - logger.info("No new memories extracted for %s", sender) except asyncio.TimeoutError: logger.warning("Memory extraction timed out for %s", sender) except Exception: diff --git a/docker-compose.yml b/docker-compose.yml index d5b08f0..1f8d7c0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -17,8 +17,45 @@ services: - DEFAULT_MODEL - WILDFILES_BASE_URL - WILDFILES_ORG + - MEMORY_SERVICE_URL=http://memory-service:8090 volumes: - bot-data:/data + depends_on: + memory-service: + condition: service_healthy + + memory-db: + image: pgvector/pgvector:pg17 + restart: unless-stopped + environment: + POSTGRES_USER: memory + POSTGRES_PASSWORD: ${MEMORY_DB_PASSWORD:-memory} + POSTGRES_DB: memories + volumes: + - memory-pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U memory -d memories"] + interval: 5s + timeout: 3s + retries: 5 + + memory-service: + build: ./memory-service + restart: unless-stopped + environment: + DATABASE_URL: postgresql://memory:${MEMORY_DB_PASSWORD:-memory}@memory-db:5432/memories + LITELLM_BASE_URL: ${LITELLM_BASE_URL} + LITELLM_API_KEY: ${LITELLM_API_KEY:-not-needed} + EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small} + depends_on: + memory-db: + condition: service_healthy + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8090/health')"] + interval: 10s + timeout: 5s + retries: 3 volumes: bot-data: + memory-pgdata: diff --git a/memory-service/Dockerfile b/memory-service/Dockerfile new file mode 100644 index 0000000..6df5823 --- /dev/null +++ b/memory-service/Dockerfile @@ -0,0 +1,6 @@ +FROM python:3.11-slim +WORKDIR /app +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt +COPY main.py . 
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8090"] diff --git a/memory-service/main.py b/memory-service/main.py new file mode 100644 index 0000000..9b107ac --- /dev/null +++ b/memory-service/main.py @@ -0,0 +1,191 @@ +import os +import logging +import time + +import asyncpg +import httpx +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel + +logger = logging.getLogger("memory-service") +logging.basicConfig(level=logging.INFO) + +DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories") +LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "") +LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed") +EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small") +EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536")) +DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92")) + +app = FastAPI(title="Memory Service") +pool: asyncpg.Pool | None = None + + +class StoreRequest(BaseModel): + user_id: str + fact: str + source_room: str = "" + + +class QueryRequest(BaseModel): + user_id: str + query: str + top_k: int = 10 + + +async def _embed(text: str) -> list[float]: + """Get embedding vector from LiteLLM /embeddings endpoint.""" + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{LITELLM_URL}/embeddings", + json={"model": EMBED_MODEL, "input": text}, + headers={"Authorization": f"Bearer {LITELLM_KEY}"}, + ) + resp.raise_for_status() + return resp.json()["data"][0]["embedding"] + + +async def _init_db(): + """Create pgvector extension and memories table if not exists.""" + global pool + pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10) + async with pool.acquire() as conn: + await conn.execute("CREATE EXTENSION IF NOT EXISTS vector") + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS memories ( + id BIGSERIAL PRIMARY KEY, + user_id TEXT NOT NULL, + fact TEXT NOT NULL, + source_room TEXT DEFAULT '', + created_at DOUBLE PRECISION NOT NULL, + embedding vector({EMBED_DIMS}) NOT NULL + ) + """) + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id) + """) + await conn.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memories_embedding + ON memories USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 10) + """) + logger.info("Database initialized (dims=%d)", EMBED_DIMS) + + +@app.on_event("startup") +async def startup(): + await _init_db() + + +@app.on_event("shutdown") +async def shutdown(): + if pool: + await pool.close() + + +@app.get("/health") +async def health(): + if pool: + async with pool.acquire() as conn: + count = await conn.fetchval("SELECT count(*) FROM memories") + return {"status": "ok", "total_memories": count} + return {"status": "no_db"} + + +@app.post("/memories/store") +async def store_memory(req: StoreRequest): + """Embed fact, deduplicate by cosine similarity, insert.""" + if not req.fact.strip(): + raise HTTPException(400, "Empty fact") + + embedding = await _embed(req.fact) + vec_literal = "[" + ",".join(str(v) for v in embedding) + "]" + + async with pool.acquire() as conn: + # Check for duplicates (cosine similarity > threshold) + dup = await conn.fetchval( + """ + SELECT id FROM memories + WHERE user_id = $1 + AND 1 - (embedding <=> $2::vector) > $3 + LIMIT 1 + """, + req.user_id, vec_literal, DEDUP_THRESHOLD, + ) + if dup: + logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, dup) + return {"stored": False, "reason": "duplicate"} + + await 
conn.execute( + """ + INSERT INTO memories (user_id, fact, source_room, created_at, embedding) + VALUES ($1, $2, $3, $4, $5::vector) + """, + req.user_id, req.fact.strip(), req.source_room, time.time(), vec_literal, + ) + logger.info("Stored memory for %s: %s", req.user_id, req.fact[:60]) + return {"stored": True} + + +@app.post("/memories/query") +async def query_memories(req: QueryRequest): + """Embed query, return top-K similar facts for user.""" + embedding = await _embed(req.query) + vec_literal = "[" + ",".join(str(v) for v in embedding) + "]" + + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT fact, source_room, created_at, + 1 - (embedding <=> $1::vector) AS similarity + FROM memories + WHERE user_id = $2 + ORDER BY embedding <=> $1::vector + LIMIT $3 + """, + vec_literal, req.user_id, req.top_k, + ) + + results = [ + { + "fact": r["fact"], + "source_room": r["source_room"], + "created_at": r["created_at"], + "similarity": float(r["similarity"]), + } + for r in rows + ] + return {"results": results} + + +@app.delete("/memories/{user_id}") +async def delete_user_memories(user_id: str): + """GDPR delete — remove all memories for a user.""" + async with pool.acquire() as conn: + result = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id) + count = int(result.split()[-1]) + logger.info("Deleted %d memories for %s", count, user_id) + return {"deleted": count} + + +@app.get("/memories/{user_id}") +async def list_user_memories(user_id: str): + """List all memories for a user (for UI/debug).""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT fact, source_room, created_at + FROM memories + WHERE user_id = $1 + ORDER BY created_at DESC + """, + user_id, + ) + return { + "user_id": user_id, + "count": len(rows), + "memories": [ + {"fact": r["fact"], "source_room": r["source_room"], "created_at": r["created_at"]} + for r in rows + ], + } diff --git a/memory-service/migrate_json.py b/memory-service/migrate_json.py new file mode 100644 index 0000000..d4de7d2 --- /dev/null +++ b/memory-service/migrate_json.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +"""One-time migration: read JSON memory files, embed each fact, insert into pgvector.""" + +import asyncio +import json +import logging +import os +import sys +import time + +import asyncpg +import httpx + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("migrate") + +DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories") +LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "") +LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed") +EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small") +MEMORIES_DIR = os.environ.get("MEMORIES_DIR", "/data/memories") + + +async def embed(text: str) -> list[float]: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{LITELLM_URL}/embeddings", + json={"model": EMBED_MODEL, "input": text}, + headers={"Authorization": f"Bearer {LITELLM_KEY}"}, + ) + resp.raise_for_status() + return resp.json()["data"][0]["embedding"] + + +async def main(): + if not os.path.isdir(MEMORIES_DIR): + logger.error("MEMORIES_DIR %s does not exist", MEMORIES_DIR) + sys.exit(1) + + json_files = [f for f in os.listdir(MEMORIES_DIR) if f.endswith(".json")] + if not json_files: + logger.info("No JSON memory files found in %s", MEMORIES_DIR) + return + + logger.info("Found %d memory files to migrate", len(json_files)) + + pool = await 
asyncpg.create_pool(DB_DSN, min_size=1, max_size=5)
+
+    total_migrated = 0
+    total_skipped = 0
+
+    for filename in json_files:
+        filepath = os.path.join(MEMORIES_DIR, filename)
+        try:
+            with open(filepath) as f:
+                memories = json.load(f)
+        except (json.JSONDecodeError, OSError) as e:
+            logger.warning("Skipping %s: %s", filename, e)
+            continue
+
+        if not memories:
+            continue
+
+        # Memory files are named sha256(user_id)[:16].json, so the original
+        # Matrix user ID cannot be recovered from the filename alone.
+        # Rebuilding the mapping would mean rehashing every ID found in the
+        # bot's user_keys.json and matching it against these filenames.
+        #
+        # For now the rows are inserted under the filename hash as a
+        # placeholder user_id. The bot queries by the real Matrix ID, so
+        # migrated rows stay dormant until an operator applies such a
+        # mapping and updates the user_id column.
+        user_hash = filename.replace(".json", "")
+
+        for mem in memories:
+            fact = mem.get("fact", "").strip()
+            if not fact:
+                continue
+
+            try:
+                embedding = await embed(fact)
+            except Exception as e:
+                logger.warning("Embedding failed for fact '%s': %s", fact[:50], e)
+                total_skipped += 1
+                continue
+
+            vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
+            created_at = mem.get("created", time.time())
+            source_room = mem.get("source_room", "")
+
+            async with pool.acquire() as conn:
+                await conn.execute(
+                    """
+                    INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
+                    VALUES ($1, $2, $3, $4, $5::vector)
+                    """,
+                    user_hash, fact, source_room, created_at, vec_literal,
+                )
+            total_migrated += 1
+
+        logger.info("Processed %s: %d entries", filename, len(memories))
+
+    await pool.close()
+    logger.info("Migration complete: %d migrated, %d skipped", total_migrated, total_skipped)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/memory-service/requirements.txt b/memory-service/requirements.txt
new file mode 100644
index 0000000..5f17c58
--- /dev/null
+++ b/memory-service/requirements.txt
@@ -0,0 +1,5 @@
+fastapi>=0.115,<1.0
+uvicorn>=0.34,<1.0
+asyncpg>=0.30,<1.0
+pgvector>=0.3,<1.0
+httpx>=0.27,<1.0
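
A quick way to smoke-test the new endpoints once the stack is up (a sketch,
not part of the applied diff; it assumes port 8090 is temporarily published
on the host, e.g. ports: ["8090:8090"] on memory-service, or that it runs
from a container attached to the compose network; the user ID is just an
example):

    import asyncio
    import httpx

    BASE = "http://localhost:8090"   # assumption: port published for testing
    USER = "@alice:example.org"      # example Matrix user ID

    async def smoke_test():
        async with httpx.AsyncClient(timeout=30.0) as client:
            # Store a fact (embedded and deduplicated server-side).
            r = await client.post(f"{BASE}/memories/store", json={
                "user_id": USER,
                "fact": "Prefers answers in German",
                "source_room": "!test:example.org",
            })
            print("store:", r.json())

            # Semantic query: the stored fact should rank near the top.
            r = await client.post(f"{BASE}/memories/query", json={
                "user_id": USER,
                "query": "what language does this user want",
                "top_k": 5,
            })
            print("query:", r.json())

            # List, then GDPR-style cleanup.
            print("list:", (await client.get(f"{BASE}/memories/{USER}")).json())
            print("delete:", (await client.delete(f"{BASE}/memories/{USER}")).json())

    asyncio.run(smoke_test())

For migrating existing data, note that the Dockerfile only copies main.py,
so migrate_json.py has to be bind-mounted (or added with an extra COPY line)
and run inside the memory-service container with the bot's data volume
mounted at /data, since the script reads MEMORIES_DIR (default
/data/memories) and reaches memory-db over the compose network.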