feat: Replace JSON memory with pgvector semantic search (MAT-11)
Add memory-service (FastAPI + pgvector) for semantic memory storage. Bot now queries relevant memories per conversation instead of dumping all 50. Includes migration script for existing JSON files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
153
bot.py
153
bot.py
@@ -8,8 +8,6 @@ import re
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
import fitz # pymupdf
|
import fitz # pymupdf
|
||||||
import httpx
|
import httpx
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
@@ -60,8 +58,7 @@ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
|
|||||||
WILDFILES_BASE_URL = os.environ.get("WILDFILES_BASE_URL", "")
|
WILDFILES_BASE_URL = os.environ.get("WILDFILES_BASE_URL", "")
|
||||||
WILDFILES_ORG = os.environ.get("WILDFILES_ORG", "")
|
WILDFILES_ORG = os.environ.get("WILDFILES_ORG", "")
|
||||||
USER_KEYS_FILE = os.environ.get("USER_KEYS_FILE", "/data/user_keys.json")
|
USER_KEYS_FILE = os.environ.get("USER_KEYS_FILE", "/data/user_keys.json")
|
||||||
MEMORIES_DIR = os.environ.get("MEMORIES_DIR", "/data/memories")
|
MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
|
||||||
MAX_MEMORIES_PER_USER = 50
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = """You are a helpful AI assistant in a Matrix chat room.
|
SYSTEM_PROMPT = """You are a helpful AI assistant in a Matrix chat room.
|
||||||
Keep answers concise but thorough. Use markdown formatting when helpful.
|
Keep answers concise but thorough. Use markdown formatting when helpful.
|
||||||
@@ -190,6 +187,65 @@ class DocumentRAG:
|
|||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryClient:
|
||||||
|
"""Async HTTP client for the memory-service."""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str):
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.enabled = bool(base_url)
|
||||||
|
|
||||||
|
async def store(self, user_id: str, fact: str, source_room: str = ""):
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
await client.post(
|
||||||
|
f"{self.base_url}/memories/store",
|
||||||
|
json={"user_id": user_id, "fact": fact, "source_room": source_room},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Memory store failed", exc_info=True)
|
||||||
|
|
||||||
|
async def query(self, user_id: str, query: str, top_k: int = 10) -> list[dict]:
|
||||||
|
if not self.enabled:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{self.base_url}/memories/query",
|
||||||
|
json={"user_id": user_id, "query": query, "top_k": top_k},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json().get("results", [])
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Memory query failed", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def delete_user(self, user_id: str) -> int:
|
||||||
|
if not self.enabled:
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.delete(f"{self.base_url}/memories/{user_id}")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json().get("deleted", 0)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Memory delete failed", exc_info=True)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def list_all(self, user_id: str) -> list[dict]:
|
||||||
|
if not self.enabled:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.get(f"{self.base_url}/memories/{user_id}")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json().get("memories", [])
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Memory list failed", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class Bot:
|
class Bot:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
config = AsyncClientConfig(
|
config = AsyncClientConfig(
|
||||||
@@ -208,6 +264,7 @@ class Bot:
|
|||||||
self.dispatched_rooms = set()
|
self.dispatched_rooms = set()
|
||||||
self.active_calls = set() # rooms where we've sent call member event
|
self.active_calls = set() # rooms where we've sent call member event
|
||||||
self.rag = DocumentRAG(WILDFILES_BASE_URL, WILDFILES_ORG)
|
self.rag = DocumentRAG(WILDFILES_BASE_URL, WILDFILES_ORG)
|
||||||
|
self.memory = MemoryClient(MEMORY_SERVICE_URL)
|
||||||
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
|
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
|
||||||
self.user_keys: dict[str, str] = self._load_user_keys() # matrix_user_id -> api_key
|
self.user_keys: dict[str, str] = self._load_user_keys() # matrix_user_id -> api_key
|
||||||
self.room_models: dict[str, str] = {} # room_id -> model name
|
self.room_models: dict[str, str] = {} # room_id -> model name
|
||||||
@@ -414,46 +471,22 @@ class Bot:
|
|||||||
|
|
||||||
# --- User memory helpers ---
|
# --- User memory helpers ---
|
||||||
|
|
||||||
def _memory_path(self, user_id: str) -> str:
|
@staticmethod
|
||||||
"""Get the file path for a user's memory store."""
|
def _format_memories(memories: list[dict]) -> str:
|
||||||
uid_hash = hashlib.sha256(user_id.encode()).hexdigest()[:16]
|
"""Format memory query results as a system prompt section."""
|
||||||
return os.path.join(MEMORIES_DIR, f"{uid_hash}.json")
|
|
||||||
|
|
||||||
def _load_memories(self, user_id: str) -> list[dict]:
|
|
||||||
"""Load memories for a user. Returns list of {fact, created, source_room}."""
|
|
||||||
path = self._memory_path(user_id)
|
|
||||||
try:
|
|
||||||
with open(path) as f:
|
|
||||||
return json.load(f)
|
|
||||||
except (FileNotFoundError, json.JSONDecodeError):
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _save_memories(self, user_id: str, memories: list[dict]):
|
|
||||||
"""Save memories for a user, capping at MAX_MEMORIES_PER_USER."""
|
|
||||||
os.makedirs(MEMORIES_DIR, exist_ok=True)
|
|
||||||
# Keep only the most recent memories
|
|
||||||
memories = memories[-MAX_MEMORIES_PER_USER:]
|
|
||||||
path = self._memory_path(user_id)
|
|
||||||
with open(path, "w") as f:
|
|
||||||
json.dump(memories, f, indent=2)
|
|
||||||
|
|
||||||
def _format_memories(self, memories: list[dict]) -> str:
|
|
||||||
"""Format memories as a system prompt section."""
|
|
||||||
if not memories:
|
if not memories:
|
||||||
return ""
|
return ""
|
||||||
facts = [m["fact"] for m in memories]
|
facts = [m["fact"] for m in memories]
|
||||||
return "You have these memories about this user:\n" + "\n".join(f"- {f}" for f in facts)
|
return "You have these memories about this user:\n" + "\n".join(f"- {f}" for f in facts)
|
||||||
|
|
||||||
async def _extract_memories(self, user_message: str, ai_reply: str,
|
async def _extract_and_store_memories(self, user_message: str, ai_reply: str,
|
||||||
existing: list[dict], model: str,
|
existing_facts: list[str], model: str,
|
||||||
sender: str, room_id: str) -> list[dict]:
|
sender: str, room_id: str):
|
||||||
"""Use LLM to extract memorable facts from the conversation, deduplicate with existing."""
|
"""Use LLM to extract memorable facts, then store each via memory-service."""
|
||||||
if not self.llm:
|
if not self.llm:
|
||||||
return existing
|
return
|
||||||
|
|
||||||
existing_facts = [m["fact"] for m in existing]
|
|
||||||
existing_text = "\n".join(f"- {f}" for f in existing_facts) if existing_facts else "(none)"
|
existing_text = "\n".join(f"- {f}" for f in existing_facts) if existing_facts else "(none)"
|
||||||
|
|
||||||
logger.info("Memory extraction: user_msg=%s... (%d existing facts)", user_message[:80], len(existing_facts))
|
logger.info("Memory extraction: user_msg=%s... (%d existing facts)", user_message[:80], len(existing_facts))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -481,7 +514,6 @@ class Bot:
|
|||||||
raw = resp.choices[0].message.content.strip()
|
raw = resp.choices[0].message.content.strip()
|
||||||
logger.info("Memory extraction raw response: %s", raw[:200])
|
logger.info("Memory extraction raw response: %s", raw[:200])
|
||||||
|
|
||||||
# Robust JSON extraction: strip markdown fences, find array
|
|
||||||
if raw.startswith("```"):
|
if raw.startswith("```"):
|
||||||
raw = re.sub(r"^```\w*\n?", "", raw)
|
raw = re.sub(r"^```\w*\n?", "", raw)
|
||||||
raw = re.sub(r"\n?```$", "", raw)
|
raw = re.sub(r"\n?```$", "", raw)
|
||||||
@@ -491,22 +523,17 @@ class Bot:
|
|||||||
new_facts = json.loads(raw)
|
new_facts = json.loads(raw)
|
||||||
if not isinstance(new_facts, list):
|
if not isinstance(new_facts, list):
|
||||||
logger.warning("Memory extraction returned non-list: %s", type(new_facts))
|
logger.warning("Memory extraction returned non-list: %s", type(new_facts))
|
||||||
return existing
|
return
|
||||||
|
|
||||||
logger.info("Memory extraction found %d new facts", len(new_facts))
|
logger.info("Memory extraction found %d new facts", len(new_facts))
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
for fact in new_facts:
|
for fact in new_facts:
|
||||||
if isinstance(fact, str) and fact.strip():
|
if isinstance(fact, str) and fact.strip():
|
||||||
existing.append({"fact": fact.strip(), "created": now, "source_room": room_id})
|
await self.memory.store(sender, fact.strip(), room_id)
|
||||||
|
|
||||||
return existing
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning("Memory extraction JSON parse failed, raw: %s", raw[:200])
|
logger.warning("Memory extraction JSON parse failed, raw: %s", raw[:200])
|
||||||
return existing
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Memory extraction failed", exc_info=True)
|
logger.warning("Memory extraction failed", exc_info=True)
|
||||||
return existing
|
|
||||||
|
|
||||||
async def _detect_language(self, text: str) -> str:
|
async def _detect_language(self, text: str) -> str:
|
||||||
"""Detect the language of a text using a fast LLM call."""
|
"""Detect the language of a text using a fast LLM call."""
|
||||||
@@ -544,9 +571,9 @@ class Bot:
|
|||||||
logger.debug("Translation failed", exc_info=True)
|
logger.debug("Translation failed", exc_info=True)
|
||||||
return f"[Translation failed] {text}"
|
return f"[Translation failed] {text}"
|
||||||
|
|
||||||
def _get_preferred_language(self, user_id: str) -> str:
|
async def _get_preferred_language(self, user_id: str) -> str:
|
||||||
"""Get user's preferred language from memories (last match = most recent)."""
|
"""Get user's preferred language from memories (last match = most recent)."""
|
||||||
memories = self._load_memories(user_id)
|
memories = await self.memory.query(user_id, "preferred language", top_k=5)
|
||||||
known_langs = [
|
known_langs = [
|
||||||
"English", "German", "French", "Spanish", "Italian", "Portuguese",
|
"English", "German", "French", "Spanish", "Italian", "Portuguese",
|
||||||
"Dutch", "Russian", "Chinese", "Japanese", "Korean", "Arabic",
|
"Dutch", "Russian", "Chinese", "Japanese", "Korean", "Arabic",
|
||||||
@@ -620,7 +647,7 @@ class Bot:
|
|||||||
if is_dm and sender in self._pending_translate:
|
if is_dm and sender in self._pending_translate:
|
||||||
pending = self._pending_translate.pop(sender)
|
pending = self._pending_translate.pop(sender)
|
||||||
choice = body.strip().lower()
|
choice = body.strip().lower()
|
||||||
preferred_lang = self._get_preferred_language(sender)
|
preferred_lang = await self._get_preferred_language(sender)
|
||||||
|
|
||||||
if choice in ("1", "1️⃣") or choice.startswith("translate"):
|
if choice in ("1", "1️⃣") or choice.startswith("translate"):
|
||||||
await self.client.room_typing(room.room_id, typing_state=True)
|
await self.client.room_typing(room.room_id, typing_state=True)
|
||||||
@@ -653,7 +680,7 @@ class Bot:
|
|||||||
|
|
||||||
# --- DM translation workflow: detect foreign language ---
|
# --- DM translation workflow: detect foreign language ---
|
||||||
if is_dm and not body.startswith("!ai") and not image_data:
|
if is_dm and not body.startswith("!ai") and not image_data:
|
||||||
preferred_lang = self._get_preferred_language(sender)
|
preferred_lang = await self._get_preferred_language(sender)
|
||||||
detected_lang = await self._detect_language(body)
|
detected_lang = await self._detect_language(body)
|
||||||
logger.info("Translation check: detected=%s, preferred=%s, len=%d", detected_lang, preferred_lang, len(body))
|
logger.info("Translation check: detected=%s, preferred=%s, len=%d", detected_lang, preferred_lang, len(body))
|
||||||
if (
|
if (
|
||||||
@@ -952,22 +979,17 @@ class Bot:
|
|||||||
elif cmd == "forget":
|
elif cmd == "forget":
|
||||||
sender = event.sender if event else None
|
sender = event.sender if event else None
|
||||||
if sender:
|
if sender:
|
||||||
path = self._memory_path(sender)
|
deleted = await self.memory.delete_user(sender)
|
||||||
try:
|
|
||||||
os.remove(path)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
# Clear any in-memory caches for this user
|
|
||||||
self._pending_translate.pop(sender, None)
|
self._pending_translate.pop(sender, None)
|
||||||
self._pending_reply.pop(sender, None)
|
self._pending_reply.pop(sender, None)
|
||||||
await self._send_text(room.room_id, "All my memories about you have been deleted.")
|
await self._send_text(room.room_id, f"All my memories about you have been deleted ({deleted} facts removed).")
|
||||||
else:
|
else:
|
||||||
await self._send_text(room.room_id, "Could not identify user.")
|
await self._send_text(room.room_id, "Could not identify user.")
|
||||||
|
|
||||||
elif cmd == "memories":
|
elif cmd == "memories":
|
||||||
sender = event.sender if event else None
|
sender = event.sender if event else None
|
||||||
if sender:
|
if sender:
|
||||||
memories = self._load_memories(sender)
|
memories = await self.memory.list_all(sender)
|
||||||
if memories:
|
if memories:
|
||||||
text = f"**I remember {len(memories)} things about you:**\n"
|
text = f"**I remember {len(memories)} things about you:**\n"
|
||||||
text += "\n".join(f"- {m['fact']}" for m in memories)
|
text += "\n".join(f"- {m['fact']}" for m in memories)
|
||||||
@@ -1148,8 +1170,8 @@ class Bot:
|
|||||||
else:
|
else:
|
||||||
logger.info("RAG found 0 docs for: %s (original: %s)", search_query[:50], user_message[:50])
|
logger.info("RAG found 0 docs for: %s (original: %s)", search_query[:50], user_message[:50])
|
||||||
|
|
||||||
# Load user memories
|
# Query relevant memories via semantic search
|
||||||
memories = self._load_memories(sender) if sender else []
|
memories = await self.memory.query(sender, user_message, top_k=10) if sender else []
|
||||||
memory_context = self._format_memories(memories)
|
memory_context = self._format_memories(memories)
|
||||||
|
|
||||||
# Build conversation context
|
# Build conversation context
|
||||||
@@ -1191,21 +1213,16 @@ class Bot:
|
|||||||
else:
|
else:
|
||||||
await self._send_text(room.room_id, reply)
|
await self._send_text(room.room_id, reply)
|
||||||
|
|
||||||
# Extract and save new memories (after reply sent, with timeout)
|
# Extract and store new memories (after reply sent, with timeout)
|
||||||
if sender and reply:
|
if sender and reply:
|
||||||
|
existing_facts = [m["fact"] for m in memories]
|
||||||
try:
|
try:
|
||||||
updated = await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
self._extract_memories(
|
self._extract_and_store_memories(
|
||||||
user_message, reply, memories, model, sender, room.room_id
|
user_message, reply, existing_facts, model, sender, room.room_id
|
||||||
),
|
),
|
||||||
timeout=15.0,
|
timeout=15.0,
|
||||||
)
|
)
|
||||||
if len(updated) > len(memories):
|
|
||||||
self._save_memories(sender, updated)
|
|
||||||
logger.info("Saved %d new memories for %s (total: %d)",
|
|
||||||
len(updated) - len(memories), sender, len(updated))
|
|
||||||
else:
|
|
||||||
logger.info("No new memories extracted for %s", sender)
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning("Memory extraction timed out for %s", sender)
|
logger.warning("Memory extraction timed out for %s", sender)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -17,8 +17,45 @@ services:
|
|||||||
- DEFAULT_MODEL
|
- DEFAULT_MODEL
|
||||||
- WILDFILES_BASE_URL
|
- WILDFILES_BASE_URL
|
||||||
- WILDFILES_ORG
|
- WILDFILES_ORG
|
||||||
|
- MEMORY_SERVICE_URL=http://memory-service:8090
|
||||||
volumes:
|
volumes:
|
||||||
- bot-data:/data
|
- bot-data:/data
|
||||||
|
depends_on:
|
||||||
|
memory-service:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
memory-db:
|
||||||
|
image: pgvector/pgvector:pg17
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: memory
|
||||||
|
POSTGRES_PASSWORD: ${MEMORY_DB_PASSWORD:-memory}
|
||||||
|
POSTGRES_DB: memories
|
||||||
|
volumes:
|
||||||
|
- memory-pgdata:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U memory -d memories"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
memory-service:
|
||||||
|
build: ./memory-service
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql://memory:${MEMORY_DB_PASSWORD:-memory}@memory-db:5432/memories
|
||||||
|
LITELLM_BASE_URL: ${LITELLM_BASE_URL}
|
||||||
|
LITELLM_API_KEY: ${LITELLM_API_KEY:-not-needed}
|
||||||
|
EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
|
||||||
|
depends_on:
|
||||||
|
memory-db:
|
||||||
|
condition: service_healthy
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8090/health')"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 3
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
bot-data:
|
bot-data:
|
||||||
|
memory-pgdata:
|
||||||
|
|||||||
6
memory-service/Dockerfile
Normal file
6
memory-service/Dockerfile
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
FROM python:3.11-slim
|
||||||
|
WORKDIR /app
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
COPY main.py .
|
||||||
|
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8090"]
|
||||||
191
memory-service/main.py
Normal file
191
memory-service/main.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
import httpx
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger("memory-service")
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
|
||||||
|
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
|
||||||
|
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
|
||||||
|
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
|
||||||
|
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
|
||||||
|
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
|
||||||
|
|
||||||
|
app = FastAPI(title="Memory Service")
|
||||||
|
pool: asyncpg.Pool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StoreRequest(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
fact: str
|
||||||
|
source_room: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class QueryRequest(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
query: str
|
||||||
|
top_k: int = 10
|
||||||
|
|
||||||
|
|
||||||
|
async def _embed(text: str) -> list[float]:
|
||||||
|
"""Get embedding vector from LiteLLM /embeddings endpoint."""
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{LITELLM_URL}/embeddings",
|
||||||
|
json={"model": EMBED_MODEL, "input": text},
|
||||||
|
headers={"Authorization": f"Bearer {LITELLM_KEY}"},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()["data"][0]["embedding"]
|
||||||
|
|
||||||
|
|
||||||
|
async def _init_db():
|
||||||
|
"""Create pgvector extension and memories table if not exists."""
|
||||||
|
global pool
|
||||||
|
pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||||
|
await conn.execute(f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS memories (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
fact TEXT NOT NULL,
|
||||||
|
source_room TEXT DEFAULT '',
|
||||||
|
created_at DOUBLE PRECISION NOT NULL,
|
||||||
|
embedding vector({EMBED_DIMS}) NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id)
|
||||||
|
""")
|
||||||
|
await conn.execute(f"""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_memories_embedding
|
||||||
|
ON memories USING ivfflat (embedding vector_cosine_ops)
|
||||||
|
WITH (lists = 10)
|
||||||
|
""")
|
||||||
|
logger.info("Database initialized (dims=%d)", EMBED_DIMS)
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup():
|
||||||
|
await _init_db()
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown():
|
||||||
|
if pool:
|
||||||
|
await pool.close()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
if pool:
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
count = await conn.fetchval("SELECT count(*) FROM memories")
|
||||||
|
return {"status": "ok", "total_memories": count}
|
||||||
|
return {"status": "no_db"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/memories/store")
|
||||||
|
async def store_memory(req: StoreRequest):
|
||||||
|
"""Embed fact, deduplicate by cosine similarity, insert."""
|
||||||
|
if not req.fact.strip():
|
||||||
|
raise HTTPException(400, "Empty fact")
|
||||||
|
|
||||||
|
embedding = await _embed(req.fact)
|
||||||
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
# Check for duplicates (cosine similarity > threshold)
|
||||||
|
dup = await conn.fetchval(
|
||||||
|
"""
|
||||||
|
SELECT id FROM memories
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND 1 - (embedding <=> $2::vector) > $3
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
req.user_id, vec_literal, DEDUP_THRESHOLD,
|
||||||
|
)
|
||||||
|
if dup:
|
||||||
|
logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, dup)
|
||||||
|
return {"stored": False, "reason": "duplicate"}
|
||||||
|
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
|
||||||
|
VALUES ($1, $2, $3, $4, $5::vector)
|
||||||
|
""",
|
||||||
|
req.user_id, req.fact.strip(), req.source_room, time.time(), vec_literal,
|
||||||
|
)
|
||||||
|
logger.info("Stored memory for %s: %s", req.user_id, req.fact[:60])
|
||||||
|
return {"stored": True}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/memories/query")
|
||||||
|
async def query_memories(req: QueryRequest):
|
||||||
|
"""Embed query, return top-K similar facts for user."""
|
||||||
|
embedding = await _embed(req.query)
|
||||||
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
rows = await conn.fetch(
|
||||||
|
"""
|
||||||
|
SELECT fact, source_room, created_at,
|
||||||
|
1 - (embedding <=> $1::vector) AS similarity
|
||||||
|
FROM memories
|
||||||
|
WHERE user_id = $2
|
||||||
|
ORDER BY embedding <=> $1::vector
|
||||||
|
LIMIT $3
|
||||||
|
""",
|
||||||
|
vec_literal, req.user_id, req.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
"fact": r["fact"],
|
||||||
|
"source_room": r["source_room"],
|
||||||
|
"created_at": r["created_at"],
|
||||||
|
"similarity": float(r["similarity"]),
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/memories/{user_id}")
|
||||||
|
async def delete_user_memories(user_id: str):
|
||||||
|
"""GDPR delete — remove all memories for a user."""
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
result = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
|
||||||
|
count = int(result.split()[-1])
|
||||||
|
logger.info("Deleted %d memories for %s", count, user_id)
|
||||||
|
return {"deleted": count}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/memories/{user_id}")
|
||||||
|
async def list_user_memories(user_id: str):
|
||||||
|
"""List all memories for a user (for UI/debug)."""
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
rows = await conn.fetch(
|
||||||
|
"""
|
||||||
|
SELECT fact, source_room, created_at
|
||||||
|
FROM memories
|
||||||
|
WHERE user_id = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"user_id": user_id,
|
||||||
|
"count": len(rows),
|
||||||
|
"memories": [
|
||||||
|
{"fact": r["fact"], "source_room": r["source_room"], "created_at": r["created_at"]}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
}
|
||||||
108
memory-service/migrate_json.py
Normal file
108
memory-service/migrate_json.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""One-time migration: read JSON memory files, embed each fact, insert into pgvector."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger("migrate")
|
||||||
|
|
||||||
|
DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
|
||||||
|
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
|
||||||
|
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
|
||||||
|
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
|
||||||
|
MEMORIES_DIR = os.environ.get("MEMORIES_DIR", "/data/memories")
|
||||||
|
|
||||||
|
|
||||||
|
async def embed(text: str) -> list[float]:
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{LITELLM_URL}/embeddings",
|
||||||
|
json={"model": EMBED_MODEL, "input": text},
|
||||||
|
headers={"Authorization": f"Bearer {LITELLM_KEY}"},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()["data"][0]["embedding"]
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
if not os.path.isdir(MEMORIES_DIR):
|
||||||
|
logger.error("MEMORIES_DIR %s does not exist", MEMORIES_DIR)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
json_files = [f for f in os.listdir(MEMORIES_DIR) if f.endswith(".json")]
|
||||||
|
if not json_files:
|
||||||
|
logger.info("No JSON memory files found in %s", MEMORIES_DIR)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Found %d memory files to migrate", len(json_files))
|
||||||
|
|
||||||
|
pool = await asyncpg.create_pool(DB_DSN, min_size=1, max_size=5)
|
||||||
|
|
||||||
|
total_migrated = 0
|
||||||
|
total_skipped = 0
|
||||||
|
|
||||||
|
for filename in json_files:
|
||||||
|
filepath = os.path.join(MEMORIES_DIR, filename)
|
||||||
|
try:
|
||||||
|
with open(filepath) as f:
|
||||||
|
memories = json.load(f)
|
||||||
|
except (json.JSONDecodeError, OSError) as e:
|
||||||
|
logger.warning("Skipping %s: %s", filename, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not memories:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# The filename is a hash of the user_id — we need to find the user_id
|
||||||
|
# from the fact entries or use the hash as identifier.
|
||||||
|
# Since JSON files are named by sha256(user_id)[:16].json, we can't
|
||||||
|
# reverse the hash. We'll need to scan bot-data for user_keys.json
|
||||||
|
# to build a mapping, or just use the hash as user_id placeholder.
|
||||||
|
#
|
||||||
|
# Better approach: read all facts and check if any contain user identity.
|
||||||
|
# For now, use the filename hash as a temporary user_id marker.
|
||||||
|
# The bot will re-associate on next interaction.
|
||||||
|
user_hash = filename.replace(".json", "")
|
||||||
|
|
||||||
|
for mem in memories:
|
||||||
|
fact = mem.get("fact", "").strip()
|
||||||
|
if not fact:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
embedding = await embed(fact)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Embedding failed for fact '%s': %s", fact[:50], e)
|
||||||
|
total_skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
|
created_at = mem.get("created", time.time())
|
||||||
|
source_room = mem.get("source_room", "")
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
|
||||||
|
VALUES ($1, $2, $3, $4, $5::vector)
|
||||||
|
""",
|
||||||
|
user_hash, fact, source_room, created_at, vec_literal,
|
||||||
|
)
|
||||||
|
total_migrated += 1
|
||||||
|
|
||||||
|
logger.info("Migrated %s: %d facts", filename, len(memories))
|
||||||
|
|
||||||
|
await pool.close()
|
||||||
|
logger.info("Migration complete: %d migrated, %d skipped", total_migrated, total_skipped)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
5
memory-service/requirements.txt
Normal file
5
memory-service/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
fastapi>=0.115,<1.0
|
||||||
|
uvicorn>=0.34,<1.0
|
||||||
|
asyncpg>=0.30,<1.0
|
||||||
|
pgvector>=0.3,<1.0
|
||||||
|
httpx>=0.27,<1.0
|
||||||
Reference in New Issue
Block a user