feat: Replace JSON memory with pgvector semantic search (MAT-11)

Add memory-service (FastAPI + pgvector) for semantic memory storage.
Bot now queries relevant memories per conversation instead of dumping all 50.
Includes migration script for existing JSON files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Christian Gick
2026-02-20 06:25:50 +02:00
parent 0c674f1467
commit 4cd7a0262e
6 changed files with 432 additions and 68 deletions

153
bot.py
View File

@@ -8,8 +8,6 @@ import re
import time
import uuid
import hashlib
import fitz # pymupdf
import httpx
from openai import AsyncOpenAI
@@ -60,8 +58,7 @@ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
WILDFILES_BASE_URL = os.environ.get("WILDFILES_BASE_URL", "")
WILDFILES_ORG = os.environ.get("WILDFILES_ORG", "")
USER_KEYS_FILE = os.environ.get("USER_KEYS_FILE", "/data/user_keys.json")
MEMORIES_DIR = os.environ.get("MEMORIES_DIR", "/data/memories")
MAX_MEMORIES_PER_USER = 50
MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
SYSTEM_PROMPT = """You are a helpful AI assistant in a Matrix chat room.
Keep answers concise but thorough. Use markdown formatting when helpful.
@@ -190,6 +187,65 @@ class DocumentRAG:
return "\n".join(parts)
class MemoryClient:
"""Async HTTP client for the memory-service."""
def __init__(self, base_url: str):
self.base_url = base_url.rstrip("/")
self.enabled = bool(base_url)
async def store(self, user_id: str, fact: str, source_room: str = ""):
if not self.enabled:
return
try:
async with httpx.AsyncClient(timeout=10.0) as client:
await client.post(
f"{self.base_url}/memories/store",
json={"user_id": user_id, "fact": fact, "source_room": source_room},
)
except Exception:
logger.warning("Memory store failed", exc_info=True)
async def query(self, user_id: str, query: str, top_k: int = 10) -> list[dict]:
if not self.enabled:
return []
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.post(
f"{self.base_url}/memories/query",
json={"user_id": user_id, "query": query, "top_k": top_k},
)
resp.raise_for_status()
return resp.json().get("results", [])
except Exception:
logger.warning("Memory query failed", exc_info=True)
return []
async def delete_user(self, user_id: str) -> int:
if not self.enabled:
return 0
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.delete(f"{self.base_url}/memories/{user_id}")
resp.raise_for_status()
return resp.json().get("deleted", 0)
except Exception:
logger.warning("Memory delete failed", exc_info=True)
return 0
async def list_all(self, user_id: str) -> list[dict]:
if not self.enabled:
return []
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(f"{self.base_url}/memories/{user_id}")
resp.raise_for_status()
return resp.json().get("memories", [])
except Exception:
logger.warning("Memory list failed", exc_info=True)
return []
class Bot:
def __init__(self):
config = AsyncClientConfig(
@@ -208,6 +264,7 @@ class Bot:
self.dispatched_rooms = set()
self.active_calls = set() # rooms where we've sent call member event
self.rag = DocumentRAG(WILDFILES_BASE_URL, WILDFILES_ORG)
self.memory = MemoryClient(MEMORY_SERVICE_URL)
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
self.user_keys: dict[str, str] = self._load_user_keys() # matrix_user_id -> api_key
self.room_models: dict[str, str] = {} # room_id -> model name
@@ -414,46 +471,22 @@ class Bot:
# --- User memory helpers ---
def _memory_path(self, user_id: str) -> str:
"""Get the file path for a user's memory store."""
uid_hash = hashlib.sha256(user_id.encode()).hexdigest()[:16]
return os.path.join(MEMORIES_DIR, f"{uid_hash}.json")
def _load_memories(self, user_id: str) -> list[dict]:
"""Load memories for a user. Returns list of {fact, created, source_room}."""
path = self._memory_path(user_id)
try:
with open(path) as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
return []
def _save_memories(self, user_id: str, memories: list[dict]):
"""Save memories for a user, capping at MAX_MEMORIES_PER_USER."""
os.makedirs(MEMORIES_DIR, exist_ok=True)
# Keep only the most recent memories
memories = memories[-MAX_MEMORIES_PER_USER:]
path = self._memory_path(user_id)
with open(path, "w") as f:
json.dump(memories, f, indent=2)
def _format_memories(self, memories: list[dict]) -> str:
"""Format memories as a system prompt section."""
@staticmethod
def _format_memories(memories: list[dict]) -> str:
"""Format memory query results as a system prompt section."""
if not memories:
return ""
facts = [m["fact"] for m in memories]
return "You have these memories about this user:\n" + "\n".join(f"- {f}" for f in facts)
async def _extract_memories(self, user_message: str, ai_reply: str,
existing: list[dict], model: str,
sender: str, room_id: str) -> list[dict]:
"""Use LLM to extract memorable facts from the conversation, deduplicate with existing."""
async def _extract_and_store_memories(self, user_message: str, ai_reply: str,
existing_facts: list[str], model: str,
sender: str, room_id: str):
"""Use LLM to extract memorable facts, then store each via memory-service."""
if not self.llm:
return existing
return
existing_facts = [m["fact"] for m in existing]
existing_text = "\n".join(f"- {f}" for f in existing_facts) if existing_facts else "(none)"
logger.info("Memory extraction: user_msg=%s... (%d existing facts)", user_message[:80], len(existing_facts))
try:
@@ -481,7 +514,6 @@ class Bot:
raw = resp.choices[0].message.content.strip()
logger.info("Memory extraction raw response: %s", raw[:200])
# Robust JSON extraction: strip markdown fences, find array
if raw.startswith("```"):
raw = re.sub(r"^```\w*\n?", "", raw)
raw = re.sub(r"\n?```$", "", raw)
@@ -491,22 +523,17 @@ class Bot:
new_facts = json.loads(raw)
if not isinstance(new_facts, list):
logger.warning("Memory extraction returned non-list: %s", type(new_facts))
return existing
return
logger.info("Memory extraction found %d new facts", len(new_facts))
now = time.time()
for fact in new_facts:
if isinstance(fact, str) and fact.strip():
existing.append({"fact": fact.strip(), "created": now, "source_room": room_id})
await self.memory.store(sender, fact.strip(), room_id)
return existing
except json.JSONDecodeError:
logger.warning("Memory extraction JSON parse failed, raw: %s", raw[:200])
return existing
except Exception:
logger.warning("Memory extraction failed", exc_info=True)
return existing
async def _detect_language(self, text: str) -> str:
"""Detect the language of a text using a fast LLM call."""
@@ -544,9 +571,9 @@ class Bot:
logger.debug("Translation failed", exc_info=True)
return f"[Translation failed] {text}"
def _get_preferred_language(self, user_id: str) -> str:
async def _get_preferred_language(self, user_id: str) -> str:
"""Get user's preferred language from memories (last match = most recent)."""
memories = self._load_memories(user_id)
memories = await self.memory.query(user_id, "preferred language", top_k=5)
known_langs = [
"English", "German", "French", "Spanish", "Italian", "Portuguese",
"Dutch", "Russian", "Chinese", "Japanese", "Korean", "Arabic",
@@ -620,7 +647,7 @@ class Bot:
if is_dm and sender in self._pending_translate:
pending = self._pending_translate.pop(sender)
choice = body.strip().lower()
preferred_lang = self._get_preferred_language(sender)
preferred_lang = await self._get_preferred_language(sender)
if choice in ("1", "1") or choice.startswith("translate"):
await self.client.room_typing(room.room_id, typing_state=True)
@@ -653,7 +680,7 @@ class Bot:
# --- DM translation workflow: detect foreign language ---
if is_dm and not body.startswith("!ai") and not image_data:
preferred_lang = self._get_preferred_language(sender)
preferred_lang = await self._get_preferred_language(sender)
detected_lang = await self._detect_language(body)
logger.info("Translation check: detected=%s, preferred=%s, len=%d", detected_lang, preferred_lang, len(body))
if (
@@ -952,22 +979,17 @@ class Bot:
elif cmd == "forget":
sender = event.sender if event else None
if sender:
path = self._memory_path(sender)
try:
os.remove(path)
except FileNotFoundError:
pass
# Clear any in-memory caches for this user
deleted = await self.memory.delete_user(sender)
self._pending_translate.pop(sender, None)
self._pending_reply.pop(sender, None)
await self._send_text(room.room_id, "All my memories about you have been deleted.")
await self._send_text(room.room_id, f"All my memories about you have been deleted ({deleted} facts removed).")
else:
await self._send_text(room.room_id, "Could not identify user.")
elif cmd == "memories":
sender = event.sender if event else None
if sender:
memories = self._load_memories(sender)
memories = await self.memory.list_all(sender)
if memories:
text = f"**I remember {len(memories)} things about you:**\n"
text += "\n".join(f"- {m['fact']}" for m in memories)
@@ -1148,8 +1170,8 @@ class Bot:
else:
logger.info("RAG found 0 docs for: %s (original: %s)", search_query[:50], user_message[:50])
# Load user memories
memories = self._load_memories(sender) if sender else []
# Query relevant memories via semantic search
memories = await self.memory.query(sender, user_message, top_k=10) if sender else []
memory_context = self._format_memories(memories)
# Build conversation context
@@ -1191,21 +1213,16 @@ class Bot:
else:
await self._send_text(room.room_id, reply)
# Extract and save new memories (after reply sent, with timeout)
# Extract and store new memories (after reply sent, with timeout)
if sender and reply:
existing_facts = [m["fact"] for m in memories]
try:
updated = await asyncio.wait_for(
self._extract_memories(
user_message, reply, memories, model, sender, room.room_id
await asyncio.wait_for(
self._extract_and_store_memories(
user_message, reply, existing_facts, model, sender, room.room_id
),
timeout=15.0,
)
if len(updated) > len(memories):
self._save_memories(sender, updated)
logger.info("Saved %d new memories for %s (total: %d)",
len(updated) - len(memories), sender, len(updated))
else:
logger.info("No new memories extracted for %s", sender)
except asyncio.TimeoutError:
logger.warning("Memory extraction timed out for %s", sender)
except Exception: