From fb54ac2beaebbf6a33e20e060ef756403c803b32 Mon Sep 17 00:00:00 2001 From: Christian Gick Date: Sun, 1 Mar 2026 07:48:19 +0200 Subject: [PATCH] feat(MAT-13): Add conversation chunk RAG for Matrix chat history Add semantic search over past conversations alongside existing memory facts. New conversation_chunks table stores user-assistant exchanges with LLM-generated summaries embedded for retrieval. Bot queries chunks on each message and injects relevant past conversations into the system prompt. New exchanges are indexed automatically after each bot response. Memory-service: /chunks/store, /chunks/query, /chunks/bulk-store endpoints Bot: chunk query + formatting, live indexing via asyncio.gather with memory extraction Co-Authored-By: Claude Opus 4.6 --- bot.py | 102 +++++++++++++++++++++-- memory-service/main.py | 183 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 277 insertions(+), 8 deletions(-) diff --git a/bot.py b/bot.py index 7bffac8..37f9c6a 100644 --- a/bot.py +++ b/bot.py @@ -457,6 +457,39 @@ class MemoryClient: logger.warning("Memory list failed", exc_info=True) return [] + async def store_chunk(self, user_id: str, room_id: str, chunk_text: str, + summary: str, source_event_id: str = "", original_ts: float = 0.0): + if not self.enabled: + return + try: + async with httpx.AsyncClient(timeout=15.0) as client: + await client.post( + f"{self.base_url}/chunks/store", + json={ + "user_id": user_id, "room_id": room_id, + "chunk_text": chunk_text, "summary": summary, + "source_event_id": source_event_id, "original_ts": original_ts, + }, + ) + except Exception: + logger.warning("Chunk store failed", exc_info=True) + + async def query_chunks(self, query: str, user_id: str = "", room_id: str = "", + top_k: int = 5) -> list[dict]: + if not self.enabled: + return [] + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{self.base_url}/chunks/query", + json={"user_id": user_id, "room_id": room_id, "query": query, "top_k": top_k}, + ) + resp.raise_for_status() + return resp.json().get("results", []) + except Exception: + logger.warning("Chunk query failed", exc_info=True) + return [] + class AtlassianClient: """Fetches per-user Atlassian tokens from the portal and calls Atlassian REST APIs.""" @@ -1173,6 +1206,52 @@ class Bot: facts = [m["fact"] for m in memories] return "You have these memories about this user:\n" + "\n".join(f"- {f}" for f in facts) + @staticmethod + def _format_chunks(chunks: list[dict]) -> str: + """Format conversation chunk results as a system prompt section.""" + if not chunks: + return "" + parts = ["Relevant past conversations:"] + for c in chunks: + ts = c.get("original_ts", 0) + date_str = time.strftime("%Y-%m-%d", time.gmtime(ts)) if ts else "unknown" + summary = c.get("summary", "") + text = c.get("chunk_text", "") + # Truncate chunk text to ~500 chars for context window efficiency + if len(text) > 500: + text = text[:500] + "..." + parts.append(f"\n### {summary} ({date_str})\n{text}") + return "\n".join(parts) + + async def _store_conversation_chunk(self, user_message: str, ai_reply: str, + sender: str, room_id: str): + """Store a user-assistant exchange as a conversation chunk for RAG.""" + if not self.llm or not self.memory.enabled: + return + chunk_text = f"User: {user_message}\nAssistant: {ai_reply}" + try: + resp = await self.llm.chat.completions.create( + model="claude-haiku", + messages=[ + {"role": "system", "content": ( + "Summarize this conversation exchange in 1-2 sentences for search indexing. " + "Focus on the topic and key information discussed. Be concise. " + "Write the summary in the same language as the conversation." + )}, + {"role": "user", "content": chunk_text[:2000]}, + ], + max_tokens=100, + ) + summary = resp.choices[0].message.content.strip() + except Exception: + logger.warning("Chunk summarization failed, using truncated message", exc_info=True) + summary = user_message[:200] + + await self.memory.store_chunk( + user_id=sender, room_id=room_id, chunk_text=chunk_text, + summary=summary, original_ts=time.time(), + ) + async def _extract_and_store_memories(self, user_message: str, ai_reply: str, existing_facts: list[str], model: str, sender: str, room_id: str): @@ -2092,6 +2171,10 @@ class Bot: memories = await self.memory.query(sender, user_message, top_k=10) if sender else [] memory_context = self._format_memories(memories) + # Query relevant conversation chunks (RAG over chat history) + chunks = await self.memory.query_chunks(search_query, user_id=sender or "", top_k=5) + chunk_context = self._format_chunks(chunks) + # Include room document context (PDFs, Confluence pages, images uploaded to room) room_doc_context = "" room_docs = [e for e in self._room_document_context.get(room.room_id, []) @@ -2114,6 +2197,8 @@ class Bot: messages = [{"role": "system", "content": SYSTEM_PROMPT}] if memory_context: messages.append({"role": "system", "content": memory_context}) + if chunk_context: + messages.append({"role": "system", "content": chunk_context}) if doc_context: messages.append({"role": "system", "content": doc_context}) if room_doc_context: @@ -2181,20 +2266,25 @@ class Bot: if reply: await self._send_text(room.room_id, reply) - # Extract and store new memories (after reply sent, with timeout) + # Extract and store new memories + conversation chunk (after reply sent) if sender and reply: existing_facts = [m["fact"] for m in memories] try: await asyncio.wait_for( - self._extract_and_store_memories( - user_message, reply, existing_facts, model, sender, room.room_id + asyncio.gather( + self._extract_and_store_memories( + user_message, reply, existing_facts, model, sender, room.room_id + ), + self._store_conversation_chunk( + user_message, reply, sender, room.room_id + ), ), - timeout=15.0, + timeout=20.0, ) except asyncio.TimeoutError: - logger.warning("Memory extraction timed out for %s", sender) + logger.warning("Memory/chunk extraction timed out for %s", sender) except Exception: - logger.warning("Memory save failed", exc_info=True) + logger.warning("Memory/chunk save failed", exc_info=True) # Auto-rename: only for group rooms with explicit opt-in (not DMs) if room.room_id in self.auto_rename_rooms: diff --git a/memory-service/main.py b/memory-service/main.py index 9b107ac..f550488 100644 --- a/memory-service/main.py +++ b/memory-service/main.py @@ -33,6 +33,26 @@ class QueryRequest(BaseModel): top_k: int = 10 +class ChunkStoreRequest(BaseModel): + user_id: str + room_id: str + chunk_text: str + summary: str + source_event_id: str = "" + original_ts: float = 0.0 + + +class ChunkQueryRequest(BaseModel): + user_id: str = "" + room_id: str = "" + query: str + top_k: int = 5 + + +class ChunkBulkStoreRequest(BaseModel): + chunks: list[ChunkStoreRequest] + + async def _embed(text: str) -> list[float]: """Get embedding vector from LiteLLM /embeddings endpoint.""" async with httpx.AsyncClient(timeout=30.0) as client: @@ -45,6 +65,19 @@ async def _embed(text: str) -> list[float]: return resp.json()["data"][0]["embedding"] +async def _embed_batch(texts: list[str]) -> list[list[float]]: + """Get embedding vectors for a batch of texts.""" + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post( + f"{LITELLM_URL}/embeddings", + json={"model": EMBED_MODEL, "input": texts}, + headers={"Authorization": f"Bearer {LITELLM_KEY}"}, + ) + resp.raise_for_status() + data = resp.json()["data"] + return [item["embedding"] for item in sorted(data, key=lambda x: x["index"])] + + async def _init_db(): """Create pgvector extension and memories table if not exists.""" global pool @@ -69,6 +102,25 @@ async def _init_db(): ON memories USING ivfflat (embedding vector_cosine_ops) WITH (lists = 10) """) + # Conversation chunks table for RAG over chat history + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS conversation_chunks ( + id BIGSERIAL PRIMARY KEY, + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + chunk_text TEXT NOT NULL, + summary TEXT NOT NULL, + source_event_id TEXT DEFAULT '', + original_ts DOUBLE PRECISION NOT NULL, + embedding vector({EMBED_DIMS}) NOT NULL + ) + """) + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_chunks_user_id ON conversation_chunks (user_id) + """) + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id) + """) logger.info("Database initialized (dims=%d)", EMBED_DIMS) @@ -87,8 +139,9 @@ async def shutdown(): async def health(): if pool: async with pool.acquire() as conn: - count = await conn.fetchval("SELECT count(*) FROM memories") - return {"status": "ok", "total_memories": count} + mem_count = await conn.fetchval("SELECT count(*) FROM memories") + chunk_count = await conn.fetchval("SELECT count(*) FROM conversation_chunks") + return {"status": "ok", "total_memories": mem_count, "total_chunks": chunk_count} return {"status": "no_db"} @@ -189,3 +242,129 @@ async def list_user_memories(user_id: str): for r in rows ], } + + +# --- Conversation Chunks --- + + +@app.post("/chunks/store") +async def store_chunk(req: ChunkStoreRequest): + """Store a conversation chunk with its summary embedding.""" + if not req.summary.strip(): + raise HTTPException(400, "Empty summary") + + embedding = await _embed(req.summary) + vec_literal = "[" + ",".join(str(v) for v in embedding) + "]" + ts = req.original_ts or time.time() + + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO conversation_chunks + (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding) + VALUES ($1, $2, $3, $4, $5, $6, $7::vector) + """, + req.user_id, req.room_id, req.chunk_text, req.summary, + req.source_event_id, ts, vec_literal, + ) + logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60]) + return {"stored": True} + + +@app.post("/chunks/query") +async def query_chunks(req: ChunkQueryRequest): + """Semantic search over conversation chunks. Filter by user_id and/or room_id.""" + embedding = await _embed(req.query) + vec_literal = "[" + ",".join(str(v) for v in embedding) + "]" + + # Build WHERE clause dynamically based on provided filters + conditions = [] + params: list = [vec_literal] + idx = 2 + + if req.user_id: + conditions.append(f"user_id = ${idx}") + params.append(req.user_id) + idx += 1 + if req.room_id: + conditions.append(f"room_id = ${idx}") + params.append(req.room_id) + idx += 1 + + where = f"WHERE {' AND '.join(conditions)}" if conditions else "" + params.append(req.top_k) + + async with pool.acquire() as conn: + rows = await conn.fetch( + f""" + SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id, + 1 - (embedding <=> $1::vector) AS similarity + FROM conversation_chunks + {where} + ORDER BY embedding <=> $1::vector + LIMIT ${idx} + """, + *params, + ) + + results = [ + { + "chunk_text": r["chunk_text"], + "summary": r["summary"], + "room_id": r["room_id"], + "user_id": r["user_id"], + "original_ts": r["original_ts"], + "source_event_id": r["source_event_id"], + "similarity": float(r["similarity"]), + } + for r in rows + ] + return {"results": results} + + +@app.post("/chunks/bulk-store") +async def bulk_store_chunks(req: ChunkBulkStoreRequest): + """Batch store conversation chunks. Embeds summaries in batches of 20.""" + if not req.chunks: + return {"stored": 0} + + stored = 0 + batch_size = 20 + + for i in range(0, len(req.chunks), batch_size): + batch = req.chunks[i:i + batch_size] + summaries = [c.summary.strip() for c in batch] + + try: + embeddings = await _embed_batch(summaries) + except Exception: + logger.error("Batch embed failed for chunks %d-%d", i, i + len(batch), exc_info=True) + continue + + async with pool.acquire() as conn: + for chunk, embedding in zip(batch, embeddings): + vec_literal = "[" + ",".join(str(v) for v in embedding) + "]" + ts = chunk.original_ts or time.time() + await conn.execute( + """ + INSERT INTO conversation_chunks + (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding) + VALUES ($1, $2, $3, $4, $5, $6, $7::vector) + """, + chunk.user_id, chunk.room_id, chunk.chunk_text, chunk.summary, + chunk.source_event_id, ts, vec_literal, + ) + stored += 1 + + logger.info("Bulk stored %d chunks", stored) + return {"stored": stored} + + +@app.get("/chunks/{user_id}/count") +async def count_user_chunks(user_id: str): + """Count conversation chunks for a user.""" + async with pool.acquire() as conn: + count = await conn.fetchval( + "SELECT count(*) FROM conversation_chunks WHERE user_id = $1", user_id, + ) + return {"user_id": user_id, "count": count}