feat(MAT-13): Add conversation chunk RAG for Matrix chat history
Add semantic search over past conversations alongside existing memory facts. New conversation_chunks table stores user-assistant exchanges with LLM-generated summaries embedded for retrieval. Bot queries chunks on each message and injects relevant past conversations into the system prompt. New exchanges are indexed automatically after each bot response. Memory-service: /chunks/store, /chunks/query, /chunks/bulk-store endpoints Bot: chunk query + formatting, live indexing via asyncio.gather with memory extraction Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
98
bot.py
98
bot.py
@@ -457,6 +457,39 @@ class MemoryClient:
|
|||||||
logger.warning("Memory list failed", exc_info=True)
|
logger.warning("Memory list failed", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def store_chunk(self, user_id: str, room_id: str, chunk_text: str,
|
||||||
|
summary: str, source_event_id: str = "", original_ts: float = 0.0):
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||||
|
await client.post(
|
||||||
|
f"{self.base_url}/chunks/store",
|
||||||
|
json={
|
||||||
|
"user_id": user_id, "room_id": room_id,
|
||||||
|
"chunk_text": chunk_text, "summary": summary,
|
||||||
|
"source_event_id": source_event_id, "original_ts": original_ts,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Chunk store failed", exc_info=True)
|
||||||
|
|
||||||
|
async def query_chunks(self, query: str, user_id: str = "", room_id: str = "",
|
||||||
|
top_k: int = 5) -> list[dict]:
|
||||||
|
if not self.enabled:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{self.base_url}/chunks/query",
|
||||||
|
json={"user_id": user_id, "room_id": room_id, "query": query, "top_k": top_k},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json().get("results", [])
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Chunk query failed", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class AtlassianClient:
|
class AtlassianClient:
|
||||||
"""Fetches per-user Atlassian tokens from the portal and calls Atlassian REST APIs."""
|
"""Fetches per-user Atlassian tokens from the portal and calls Atlassian REST APIs."""
|
||||||
@@ -1173,6 +1206,52 @@ class Bot:
|
|||||||
facts = [m["fact"] for m in memories]
|
facts = [m["fact"] for m in memories]
|
||||||
return "You have these memories about this user:\n" + "\n".join(f"- {f}" for f in facts)
|
return "You have these memories about this user:\n" + "\n".join(f"- {f}" for f in facts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_chunks(chunks: list[dict]) -> str:
|
||||||
|
"""Format conversation chunk results as a system prompt section."""
|
||||||
|
if not chunks:
|
||||||
|
return ""
|
||||||
|
parts = ["Relevant past conversations:"]
|
||||||
|
for c in chunks:
|
||||||
|
ts = c.get("original_ts", 0)
|
||||||
|
date_str = time.strftime("%Y-%m-%d", time.gmtime(ts)) if ts else "unknown"
|
||||||
|
summary = c.get("summary", "")
|
||||||
|
text = c.get("chunk_text", "")
|
||||||
|
# Truncate chunk text to ~500 chars for context window efficiency
|
||||||
|
if len(text) > 500:
|
||||||
|
text = text[:500] + "..."
|
||||||
|
parts.append(f"\n### {summary} ({date_str})\n{text}")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
async def _store_conversation_chunk(self, user_message: str, ai_reply: str,
|
||||||
|
sender: str, room_id: str):
|
||||||
|
"""Store a user-assistant exchange as a conversation chunk for RAG."""
|
||||||
|
if not self.llm or not self.memory.enabled:
|
||||||
|
return
|
||||||
|
chunk_text = f"User: {user_message}\nAssistant: {ai_reply}"
|
||||||
|
try:
|
||||||
|
resp = await self.llm.chat.completions.create(
|
||||||
|
model="claude-haiku",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": (
|
||||||
|
"Summarize this conversation exchange in 1-2 sentences for search indexing. "
|
||||||
|
"Focus on the topic and key information discussed. Be concise. "
|
||||||
|
"Write the summary in the same language as the conversation."
|
||||||
|
)},
|
||||||
|
{"role": "user", "content": chunk_text[:2000]},
|
||||||
|
],
|
||||||
|
max_tokens=100,
|
||||||
|
)
|
||||||
|
summary = resp.choices[0].message.content.strip()
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Chunk summarization failed, using truncated message", exc_info=True)
|
||||||
|
summary = user_message[:200]
|
||||||
|
|
||||||
|
await self.memory.store_chunk(
|
||||||
|
user_id=sender, room_id=room_id, chunk_text=chunk_text,
|
||||||
|
summary=summary, original_ts=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
async def _extract_and_store_memories(self, user_message: str, ai_reply: str,
|
async def _extract_and_store_memories(self, user_message: str, ai_reply: str,
|
||||||
existing_facts: list[str], model: str,
|
existing_facts: list[str], model: str,
|
||||||
sender: str, room_id: str):
|
sender: str, room_id: str):
|
||||||
@@ -2092,6 +2171,10 @@ class Bot:
|
|||||||
memories = await self.memory.query(sender, user_message, top_k=10) if sender else []
|
memories = await self.memory.query(sender, user_message, top_k=10) if sender else []
|
||||||
memory_context = self._format_memories(memories)
|
memory_context = self._format_memories(memories)
|
||||||
|
|
||||||
|
# Query relevant conversation chunks (RAG over chat history)
|
||||||
|
chunks = await self.memory.query_chunks(search_query, user_id=sender or "", top_k=5)
|
||||||
|
chunk_context = self._format_chunks(chunks)
|
||||||
|
|
||||||
# Include room document context (PDFs, Confluence pages, images uploaded to room)
|
# Include room document context (PDFs, Confluence pages, images uploaded to room)
|
||||||
room_doc_context = ""
|
room_doc_context = ""
|
||||||
room_docs = [e for e in self._room_document_context.get(room.room_id, [])
|
room_docs = [e for e in self._room_document_context.get(room.room_id, [])
|
||||||
@@ -2114,6 +2197,8 @@ class Bot:
|
|||||||
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||||
if memory_context:
|
if memory_context:
|
||||||
messages.append({"role": "system", "content": memory_context})
|
messages.append({"role": "system", "content": memory_context})
|
||||||
|
if chunk_context:
|
||||||
|
messages.append({"role": "system", "content": chunk_context})
|
||||||
if doc_context:
|
if doc_context:
|
||||||
messages.append({"role": "system", "content": doc_context})
|
messages.append({"role": "system", "content": doc_context})
|
||||||
if room_doc_context:
|
if room_doc_context:
|
||||||
@@ -2181,20 +2266,25 @@ class Bot:
|
|||||||
if reply:
|
if reply:
|
||||||
await self._send_text(room.room_id, reply)
|
await self._send_text(room.room_id, reply)
|
||||||
|
|
||||||
# Extract and store new memories (after reply sent, with timeout)
|
# Extract and store new memories + conversation chunk (after reply sent)
|
||||||
if sender and reply:
|
if sender and reply:
|
||||||
existing_facts = [m["fact"] for m in memories]
|
existing_facts = [m["fact"] for m in memories]
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
|
asyncio.gather(
|
||||||
self._extract_and_store_memories(
|
self._extract_and_store_memories(
|
||||||
user_message, reply, existing_facts, model, sender, room.room_id
|
user_message, reply, existing_facts, model, sender, room.room_id
|
||||||
),
|
),
|
||||||
timeout=15.0,
|
self._store_conversation_chunk(
|
||||||
|
user_message, reply, sender, room.room_id
|
||||||
|
),
|
||||||
|
),
|
||||||
|
timeout=20.0,
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning("Memory extraction timed out for %s", sender)
|
logger.warning("Memory/chunk extraction timed out for %s", sender)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Memory save failed", exc_info=True)
|
logger.warning("Memory/chunk save failed", exc_info=True)
|
||||||
|
|
||||||
# Auto-rename: only for group rooms with explicit opt-in (not DMs)
|
# Auto-rename: only for group rooms with explicit opt-in (not DMs)
|
||||||
if room.room_id in self.auto_rename_rooms:
|
if room.room_id in self.auto_rename_rooms:
|
||||||
|
|||||||
@@ -33,6 +33,26 @@ class QueryRequest(BaseModel):
|
|||||||
top_k: int = 10
|
top_k: int = 10
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkStoreRequest(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
room_id: str
|
||||||
|
chunk_text: str
|
||||||
|
summary: str
|
||||||
|
source_event_id: str = ""
|
||||||
|
original_ts: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkQueryRequest(BaseModel):
|
||||||
|
user_id: str = ""
|
||||||
|
room_id: str = ""
|
||||||
|
query: str
|
||||||
|
top_k: int = 5
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkBulkStoreRequest(BaseModel):
|
||||||
|
chunks: list[ChunkStoreRequest]
|
||||||
|
|
||||||
|
|
||||||
async def _embed(text: str) -> list[float]:
|
async def _embed(text: str) -> list[float]:
|
||||||
"""Get embedding vector from LiteLLM /embeddings endpoint."""
|
"""Get embedding vector from LiteLLM /embeddings endpoint."""
|
||||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
@@ -45,6 +65,19 @@ async def _embed(text: str) -> list[float]:
|
|||||||
return resp.json()["data"][0]["embedding"]
|
return resp.json()["data"][0]["embedding"]
|
||||||
|
|
||||||
|
|
||||||
|
async def _embed_batch(texts: list[str]) -> list[list[float]]:
|
||||||
|
"""Get embedding vectors for a batch of texts."""
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{LITELLM_URL}/embeddings",
|
||||||
|
json={"model": EMBED_MODEL, "input": texts},
|
||||||
|
headers={"Authorization": f"Bearer {LITELLM_KEY}"},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()["data"]
|
||||||
|
return [item["embedding"] for item in sorted(data, key=lambda x: x["index"])]
|
||||||
|
|
||||||
|
|
||||||
async def _init_db():
|
async def _init_db():
|
||||||
"""Create pgvector extension and memories table if not exists."""
|
"""Create pgvector extension and memories table if not exists."""
|
||||||
global pool
|
global pool
|
||||||
@@ -69,6 +102,25 @@ async def _init_db():
|
|||||||
ON memories USING ivfflat (embedding vector_cosine_ops)
|
ON memories USING ivfflat (embedding vector_cosine_ops)
|
||||||
WITH (lists = 10)
|
WITH (lists = 10)
|
||||||
""")
|
""")
|
||||||
|
# Conversation chunks table for RAG over chat history
|
||||||
|
await conn.execute(f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS conversation_chunks (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
chunk_text TEXT NOT NULL,
|
||||||
|
summary TEXT NOT NULL,
|
||||||
|
source_event_id TEXT DEFAULT '',
|
||||||
|
original_ts DOUBLE PRECISION NOT NULL,
|
||||||
|
embedding vector({EMBED_DIMS}) NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_chunks_user_id ON conversation_chunks (user_id)
|
||||||
|
""")
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
|
||||||
|
""")
|
||||||
logger.info("Database initialized (dims=%d)", EMBED_DIMS)
|
logger.info("Database initialized (dims=%d)", EMBED_DIMS)
|
||||||
|
|
||||||
|
|
||||||
@@ -87,8 +139,9 @@ async def shutdown():
|
|||||||
async def health():
|
async def health():
|
||||||
if pool:
|
if pool:
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
count = await conn.fetchval("SELECT count(*) FROM memories")
|
mem_count = await conn.fetchval("SELECT count(*) FROM memories")
|
||||||
return {"status": "ok", "total_memories": count}
|
chunk_count = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
|
||||||
|
return {"status": "ok", "total_memories": mem_count, "total_chunks": chunk_count}
|
||||||
return {"status": "no_db"}
|
return {"status": "no_db"}
|
||||||
|
|
||||||
|
|
||||||
@@ -189,3 +242,129 @@ async def list_user_memories(user_id: str):
|
|||||||
for r in rows
|
for r in rows
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Conversation Chunks ---
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/chunks/store")
|
||||||
|
async def store_chunk(req: ChunkStoreRequest):
|
||||||
|
"""Store a conversation chunk with its summary embedding."""
|
||||||
|
if not req.summary.strip():
|
||||||
|
raise HTTPException(400, "Empty summary")
|
||||||
|
|
||||||
|
embedding = await _embed(req.summary)
|
||||||
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
|
ts = req.original_ts or time.time()
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO conversation_chunks
|
||||||
|
(user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
|
||||||
|
""",
|
||||||
|
req.user_id, req.room_id, req.chunk_text, req.summary,
|
||||||
|
req.source_event_id, ts, vec_literal,
|
||||||
|
)
|
||||||
|
logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60])
|
||||||
|
return {"stored": True}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/chunks/query")
|
||||||
|
async def query_chunks(req: ChunkQueryRequest):
|
||||||
|
"""Semantic search over conversation chunks. Filter by user_id and/or room_id."""
|
||||||
|
embedding = await _embed(req.query)
|
||||||
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
|
|
||||||
|
# Build WHERE clause dynamically based on provided filters
|
||||||
|
conditions = []
|
||||||
|
params: list = [vec_literal]
|
||||||
|
idx = 2
|
||||||
|
|
||||||
|
if req.user_id:
|
||||||
|
conditions.append(f"user_id = ${idx}")
|
||||||
|
params.append(req.user_id)
|
||||||
|
idx += 1
|
||||||
|
if req.room_id:
|
||||||
|
conditions.append(f"room_id = ${idx}")
|
||||||
|
params.append(req.room_id)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||||
|
params.append(req.top_k)
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
rows = await conn.fetch(
|
||||||
|
f"""
|
||||||
|
SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
|
||||||
|
1 - (embedding <=> $1::vector) AS similarity
|
||||||
|
FROM conversation_chunks
|
||||||
|
{where}
|
||||||
|
ORDER BY embedding <=> $1::vector
|
||||||
|
LIMIT ${idx}
|
||||||
|
""",
|
||||||
|
*params,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
"chunk_text": r["chunk_text"],
|
||||||
|
"summary": r["summary"],
|
||||||
|
"room_id": r["room_id"],
|
||||||
|
"user_id": r["user_id"],
|
||||||
|
"original_ts": r["original_ts"],
|
||||||
|
"source_event_id": r["source_event_id"],
|
||||||
|
"similarity": float(r["similarity"]),
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/chunks/bulk-store")
|
||||||
|
async def bulk_store_chunks(req: ChunkBulkStoreRequest):
|
||||||
|
"""Batch store conversation chunks. Embeds summaries in batches of 20."""
|
||||||
|
if not req.chunks:
|
||||||
|
return {"stored": 0}
|
||||||
|
|
||||||
|
stored = 0
|
||||||
|
batch_size = 20
|
||||||
|
|
||||||
|
for i in range(0, len(req.chunks), batch_size):
|
||||||
|
batch = req.chunks[i:i + batch_size]
|
||||||
|
summaries = [c.summary.strip() for c in batch]
|
||||||
|
|
||||||
|
try:
|
||||||
|
embeddings = await _embed_batch(summaries)
|
||||||
|
except Exception:
|
||||||
|
logger.error("Batch embed failed for chunks %d-%d", i, i + len(batch), exc_info=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
for chunk, embedding in zip(batch, embeddings):
|
||||||
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
|
ts = chunk.original_ts or time.time()
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO conversation_chunks
|
||||||
|
(user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
|
||||||
|
""",
|
||||||
|
chunk.user_id, chunk.room_id, chunk.chunk_text, chunk.summary,
|
||||||
|
chunk.source_event_id, ts, vec_literal,
|
||||||
|
)
|
||||||
|
stored += 1
|
||||||
|
|
||||||
|
logger.info("Bulk stored %d chunks", stored)
|
||||||
|
return {"stored": stored}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/chunks/{user_id}/count")
|
||||||
|
async def count_user_chunks(user_id: str):
|
||||||
|
"""Count conversation chunks for a user."""
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
count = await conn.fetchval(
|
||||||
|
"SELECT count(*) FROM conversation_chunks WHERE user_id = $1", user_id,
|
||||||
|
)
|
||||||
|
return {"user_id": user_id, "count": count}
|
||||||
|
|||||||
Reference in New Issue
Block a user