feat: Replace JSON memory with pgvector semantic search (MAT-11)
Add memory-service (FastAPI + pgvector) for semantic memory storage. Bot now queries relevant memories per conversation instead of dumping all 50. Includes migration script for existing JSON files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit modifies: bot.py (153 lines changed)
@@ -8,8 +8,6 @@ import re
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import hashlib
|
||||
|
||||
import fitz # pymupdf
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
@@ -60,8 +58,7 @@ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
|
||||
WILDFILES_BASE_URL = os.environ.get("WILDFILES_BASE_URL", "")
|
||||
WILDFILES_ORG = os.environ.get("WILDFILES_ORG", "")
|
||||
USER_KEYS_FILE = os.environ.get("USER_KEYS_FILE", "/data/user_keys.json")
|
||||
MEMORIES_DIR = os.environ.get("MEMORIES_DIR", "/data/memories")
|
||||
MAX_MEMORIES_PER_USER = 50
|
||||
MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
|
||||
|
||||
SYSTEM_PROMPT = """You are a helpful AI assistant in a Matrix chat room.
|
||||
Keep answers concise but thorough. Use markdown formatting when helpful.
|
||||
@@ -190,6 +187,65 @@ class DocumentRAG:
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
class MemoryClient:
    """Async HTTP client for the memory-service.

    All calls are best-effort: any network or service failure is logged
    (with traceback) and a safe default is returned, so the bot keeps
    working even when the memory-service is down. Passing an empty
    ``base_url`` disables the client entirely and every method becomes a
    no-op returning its default.
    """

    def __init__(self, base_url: str):
        # Strip the trailing slash so the f-string path joins below never
        # produce a double slash.
        self.base_url = base_url.rstrip("/")
        # Empty base URL => client disabled (all methods short-circuit).
        self.enabled = bool(base_url)

    async def store(self, user_id: str, fact: str, source_room: str = "") -> None:
        """Store one memory fact for *user_id* (fire-and-forget).

        Failures are logged, never raised.
        """
        if not self.enabled:
            return
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(
                    f"{self.base_url}/memories/store",
                    json={"user_id": user_id, "fact": fact, "source_room": source_room},
                )
                # Fix: surface 4xx/5xx in the log instead of silently
                # dropping the fact (the other methods already do this).
                resp.raise_for_status()
        except Exception:
            logger.warning("Memory store failed", exc_info=True)

    async def query(self, user_id: str, query: str, top_k: int = 10) -> list[dict]:
        """Return up to *top_k* semantically relevant memories for *user_id*.

        Returns a list of result dicts (each expected to carry a "fact"
        key), or [] when disabled or on any failure.
        """
        if not self.enabled:
            return []
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(
                    f"{self.base_url}/memories/query",
                    json={"user_id": user_id, "query": query, "top_k": top_k},
                )
                resp.raise_for_status()
                return resp.json().get("results", [])
        except Exception:
            logger.warning("Memory query failed", exc_info=True)
            return []

    async def delete_user(self, user_id: str) -> int:
        """Delete all memories for *user_id*; return the number removed.

        Returns 0 when disabled or on any failure.
        """
        if not self.enabled:
            return 0
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                # NOTE(review): user_id is interpolated into the URL path
                # unescaped. Matrix IDs (@local:server) are path-safe per
                # RFC 3986, but an id containing "/" would break routing —
                # confirm against the memory-service route definition.
                resp = await client.delete(f"{self.base_url}/memories/{user_id}")
                resp.raise_for_status()
                return resp.json().get("deleted", 0)
        except Exception:
            logger.warning("Memory delete failed", exc_info=True)
            return 0

    async def list_all(self, user_id: str) -> list[dict]:
        """Return every stored memory for *user_id* (no relevance filter).

        Returns [] when disabled or on any failure.
        """
        if not self.enabled:
            return []
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(f"{self.base_url}/memories/{user_id}")
                resp.raise_for_status()
                return resp.json().get("memories", [])
        except Exception:
            logger.warning("Memory list failed", exc_info=True)
            return []
|
||||
|
||||
|
||||
class Bot:
|
||||
def __init__(self):
|
||||
config = AsyncClientConfig(
|
||||
@@ -208,6 +264,7 @@ class Bot:
|
||||
self.dispatched_rooms = set()
|
||||
self.active_calls = set() # rooms where we've sent call member event
|
||||
self.rag = DocumentRAG(WILDFILES_BASE_URL, WILDFILES_ORG)
|
||||
self.memory = MemoryClient(MEMORY_SERVICE_URL)
|
||||
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
|
||||
self.user_keys: dict[str, str] = self._load_user_keys() # matrix_user_id -> api_key
|
||||
self.room_models: dict[str, str] = {} # room_id -> model name
|
||||
@@ -414,46 +471,22 @@ class Bot:
|
||||
|
||||
# --- User memory helpers ---
|
||||
|
||||
def _memory_path(self, user_id: str) -> str:
|
||||
"""Get the file path for a user's memory store."""
|
||||
uid_hash = hashlib.sha256(user_id.encode()).hexdigest()[:16]
|
||||
return os.path.join(MEMORIES_DIR, f"{uid_hash}.json")
|
||||
|
||||
def _load_memories(self, user_id: str) -> list[dict]:
|
||||
"""Load memories for a user. Returns list of {fact, created, source_room}."""
|
||||
path = self._memory_path(user_id)
|
||||
try:
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
return []
|
||||
|
||||
def _save_memories(self, user_id: str, memories: list[dict]):
|
||||
"""Save memories for a user, capping at MAX_MEMORIES_PER_USER."""
|
||||
os.makedirs(MEMORIES_DIR, exist_ok=True)
|
||||
# Keep only the most recent memories
|
||||
memories = memories[-MAX_MEMORIES_PER_USER:]
|
||||
path = self._memory_path(user_id)
|
||||
with open(path, "w") as f:
|
||||
json.dump(memories, f, indent=2)
|
||||
|
||||
def _format_memories(self, memories: list[dict]) -> str:
|
||||
"""Format memories as a system prompt section."""
|
||||
@staticmethod
def _format_memories(memories: list[dict]) -> str:
    """Format memory query results as a system prompt section.

    Returns "" when there is nothing to report; otherwise a header line
    followed by one "- fact" bullet per memory.
    """
    if not memories:
        return ""
    bullets = "\n".join(f"- {entry['fact']}" for entry in memories)
    return f"You have these memories about this user:\n{bullets}"
|
||||
|
||||
async def _extract_memories(self, user_message: str, ai_reply: str,
|
||||
existing: list[dict], model: str,
|
||||
sender: str, room_id: str) -> list[dict]:
|
||||
"""Use LLM to extract memorable facts from the conversation, deduplicate with existing."""
|
||||
async def _extract_and_store_memories(self, user_message: str, ai_reply: str,
|
||||
existing_facts: list[str], model: str,
|
||||
sender: str, room_id: str):
|
||||
"""Use LLM to extract memorable facts, then store each via memory-service."""
|
||||
if not self.llm:
|
||||
return existing
|
||||
return
|
||||
|
||||
existing_facts = [m["fact"] for m in existing]
|
||||
existing_text = "\n".join(f"- {f}" for f in existing_facts) if existing_facts else "(none)"
|
||||
|
||||
logger.info("Memory extraction: user_msg=%s... (%d existing facts)", user_message[:80], len(existing_facts))
|
||||
|
||||
try:
|
||||
@@ -481,7 +514,6 @@ class Bot:
|
||||
raw = resp.choices[0].message.content.strip()
|
||||
logger.info("Memory extraction raw response: %s", raw[:200])
|
||||
|
||||
# Robust JSON extraction: strip markdown fences, find array
|
||||
if raw.startswith("```"):
|
||||
raw = re.sub(r"^```\w*\n?", "", raw)
|
||||
raw = re.sub(r"\n?```$", "", raw)
|
||||
@@ -491,22 +523,17 @@ class Bot:
|
||||
new_facts = json.loads(raw)
|
||||
if not isinstance(new_facts, list):
|
||||
logger.warning("Memory extraction returned non-list: %s", type(new_facts))
|
||||
return existing
|
||||
return
|
||||
|
||||
logger.info("Memory extraction found %d new facts", len(new_facts))
|
||||
|
||||
now = time.time()
|
||||
for fact in new_facts:
|
||||
if isinstance(fact, str) and fact.strip():
|
||||
existing.append({"fact": fact.strip(), "created": now, "source_room": room_id})
|
||||
await self.memory.store(sender, fact.strip(), room_id)
|
||||
|
||||
return existing
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Memory extraction JSON parse failed, raw: %s", raw[:200])
|
||||
return existing
|
||||
except Exception:
|
||||
logger.warning("Memory extraction failed", exc_info=True)
|
||||
return existing
|
||||
|
||||
async def _detect_language(self, text: str) -> str:
|
||||
"""Detect the language of a text using a fast LLM call."""
|
||||
@@ -544,9 +571,9 @@ class Bot:
|
||||
logger.debug("Translation failed", exc_info=True)
|
||||
return f"[Translation failed] {text}"
|
||||
|
||||
def _get_preferred_language(self, user_id: str) -> str:
|
||||
async def _get_preferred_language(self, user_id: str) -> str:
|
||||
"""Get user's preferred language from memories (last match = most recent)."""
|
||||
memories = self._load_memories(user_id)
|
||||
memories = await self.memory.query(user_id, "preferred language", top_k=5)
|
||||
known_langs = [
|
||||
"English", "German", "French", "Spanish", "Italian", "Portuguese",
|
||||
"Dutch", "Russian", "Chinese", "Japanese", "Korean", "Arabic",
|
||||
@@ -620,7 +647,7 @@ class Bot:
|
||||
if is_dm and sender in self._pending_translate:
|
||||
pending = self._pending_translate.pop(sender)
|
||||
choice = body.strip().lower()
|
||||
preferred_lang = self._get_preferred_language(sender)
|
||||
preferred_lang = await self._get_preferred_language(sender)
|
||||
|
||||
if choice in ("1", "1️⃣") or choice.startswith("translate"):
|
||||
await self.client.room_typing(room.room_id, typing_state=True)
|
||||
@@ -653,7 +680,7 @@ class Bot:
|
||||
|
||||
# --- DM translation workflow: detect foreign language ---
|
||||
if is_dm and not body.startswith("!ai") and not image_data:
|
||||
preferred_lang = self._get_preferred_language(sender)
|
||||
preferred_lang = await self._get_preferred_language(sender)
|
||||
detected_lang = await self._detect_language(body)
|
||||
logger.info("Translation check: detected=%s, preferred=%s, len=%d", detected_lang, preferred_lang, len(body))
|
||||
if (
|
||||
@@ -952,22 +979,17 @@ class Bot:
|
||||
elif cmd == "forget":
|
||||
sender = event.sender if event else None
|
||||
if sender:
|
||||
path = self._memory_path(sender)
|
||||
try:
|
||||
os.remove(path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
# Clear any in-memory caches for this user
|
||||
deleted = await self.memory.delete_user(sender)
|
||||
self._pending_translate.pop(sender, None)
|
||||
self._pending_reply.pop(sender, None)
|
||||
await self._send_text(room.room_id, "All my memories about you have been deleted.")
|
||||
await self._send_text(room.room_id, f"All my memories about you have been deleted ({deleted} facts removed).")
|
||||
else:
|
||||
await self._send_text(room.room_id, "Could not identify user.")
|
||||
|
||||
elif cmd == "memories":
|
||||
sender = event.sender if event else None
|
||||
if sender:
|
||||
memories = self._load_memories(sender)
|
||||
memories = await self.memory.list_all(sender)
|
||||
if memories:
|
||||
text = f"**I remember {len(memories)} things about you:**\n"
|
||||
text += "\n".join(f"- {m['fact']}" for m in memories)
|
||||
@@ -1148,8 +1170,8 @@ class Bot:
|
||||
else:
|
||||
logger.info("RAG found 0 docs for: %s (original: %s)", search_query[:50], user_message[:50])
|
||||
|
||||
# Load user memories
|
||||
memories = self._load_memories(sender) if sender else []
|
||||
# Query relevant memories via semantic search
|
||||
memories = await self.memory.query(sender, user_message, top_k=10) if sender else []
|
||||
memory_context = self._format_memories(memories)
|
||||
|
||||
# Build conversation context
|
||||
@@ -1191,21 +1213,16 @@ class Bot:
|
||||
else:
|
||||
await self._send_text(room.room_id, reply)
|
||||
|
||||
# Extract and save new memories (after reply sent, with timeout)
|
||||
# Extract and store new memories (after reply sent, with timeout)
|
||||
if sender and reply:
|
||||
existing_facts = [m["fact"] for m in memories]
|
||||
try:
|
||||
updated = await asyncio.wait_for(
|
||||
self._extract_memories(
|
||||
user_message, reply, memories, model, sender, room.room_id
|
||||
await asyncio.wait_for(
|
||||
self._extract_and_store_memories(
|
||||
user_message, reply, existing_facts, model, sender, room.room_id
|
||||
),
|
||||
timeout=15.0,
|
||||
)
|
||||
if len(updated) > len(memories):
|
||||
self._save_memories(sender, updated)
|
||||
logger.info("Saved %d new memories for %s (total: %d)",
|
||||
len(updated) - len(memories), sender, len(updated))
|
||||
else:
|
||||
logger.info("No new memories extracted for %s", sender)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Memory extraction timed out for %s", sender)
|
||||
except Exception:
|
||||
|
||||
Reference in New Issue
Block a user