feat(MAT-13): Add conversation chunk RAG for Matrix chat history
Add semantic search over past conversations alongside the existing memory facts. A new conversation_chunks table stores user-assistant exchanges with LLM-generated summaries embedded for retrieval. The bot queries chunks on each message and injects relevant past conversations into the system prompt; new exchanges are indexed automatically after each bot response. Memory-service: new /chunks/store, /chunks/query, and /chunks/bulk-store endpoints. Bot: chunk query and formatting, plus live indexing via asyncio.gather alongside memory extraction. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -33,6 +33,26 @@ class QueryRequest(BaseModel):
|
||||
top_k: int = 10
|
||||
|
||||
|
||||
class ChunkStoreRequest(BaseModel):
    """Payload for storing one conversation chunk.

    ``chunk_text`` holds the raw user-assistant exchange; ``summary`` is the
    text that actually gets embedded for retrieval.
    """

    user_id: str
    room_id: str
    chunk_text: str
    summary: str
    # Matrix event id of the originating message, when known.
    source_event_id: str = ""
    # Unix timestamp of the original exchange; 0.0 means "use the store time".
    original_ts: float = 0.0
|
||||
|
||||
|
||||
class ChunkQueryRequest(BaseModel):
    """Payload for semantic search over conversation chunks.

    Empty-string filters mean "do not filter on that column".
    """

    # The natural-language search text (required). Declared first: Pydantic
    # permits a required field after defaulted ones, but the conventional
    # required-fields-first order reads better and avoids confusion with
    # dataclass semantics. Pydantic models init by keyword, so reordering
    # is backward compatible for callers.
    query: str
    # Optional equality filters.
    user_id: str = ""
    room_id: str = ""
    # Maximum number of results to return.
    top_k: int = 5
|
||||
|
||||
|
||||
class ChunkBulkStoreRequest(BaseModel):
    """Payload for batch-storing many conversation chunks in one request."""

    chunks: list[ChunkStoreRequest]
|
||||
|
||||
|
||||
async def _embed(text: str) -> list[float]:
|
||||
"""Get embedding vector from LiteLLM /embeddings endpoint."""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
@@ -45,6 +65,19 @@ async def _embed(text: str) -> list[float]:
|
||||
return resp.json()["data"][0]["embedding"]
|
||||
|
||||
|
||||
async def _embed_batch(texts: list[str]) -> list[list[float]]:
    """Embed several texts with a single LiteLLM /embeddings call.

    Returns vectors in the same order as *texts* — the API may report items
    out of order, so results are sorted by their declared index.
    """
    payload = {"model": EMBED_MODEL, "input": texts}
    auth_header = {"Authorization": f"Bearer {LITELLM_KEY}"}
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings",
            json=payload,
            headers=auth_header,
        )
        resp.raise_for_status()
    items = resp.json()["data"]
    items.sort(key=lambda item: item["index"])
    return [item["embedding"] for item in items]
|
||||
|
||||
|
||||
async def _init_db():
|
||||
"""Create pgvector extension and memories table if not exists."""
|
||||
global pool
|
||||
@@ -69,6 +102,25 @@ async def _init_db():
|
||||
ON memories USING ivfflat (embedding vector_cosine_ops)
|
||||
WITH (lists = 10)
|
||||
""")
|
||||
# Conversation chunks table for RAG over chat history
|
||||
await conn.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS conversation_chunks (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
chunk_text TEXT NOT NULL,
|
||||
summary TEXT NOT NULL,
|
||||
source_event_id TEXT DEFAULT '',
|
||||
original_ts DOUBLE PRECISION NOT NULL,
|
||||
embedding vector({EMBED_DIMS}) NOT NULL
|
||||
)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_user_id ON conversation_chunks (user_id)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
|
||||
""")
|
||||
logger.info("Database initialized (dims=%d)", EMBED_DIMS)
|
||||
|
||||
|
||||
@@ -87,8 +139,9 @@ async def shutdown():
|
||||
async def health():
    """Health check: report memory and chunk row counts when the DB is up.

    Returns ``{"status": "no_db"}`` when the pool has not been initialized.
    """
    # The diff superposition left the pre-refactor body (`count = ...;
    # return {...}`) ahead of the updated counts, making the new
    # mem_count/chunk_count code unreachable. Keep only the updated body.
    if pool:
        async with pool.acquire() as conn:
            mem_count = await conn.fetchval("SELECT count(*) FROM memories")
            chunk_count = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
            return {"status": "ok", "total_memories": mem_count, "total_chunks": chunk_count}
    return {"status": "no_db"}
|
||||
|
||||
|
||||
@@ -189,3 +242,129 @@ async def list_user_memories(user_id: str):
|
||||
for r in rows
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# --- Conversation Chunks ---
|
||||
|
||||
|
||||
@app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest):
    """Store a conversation chunk with its summary embedding."""
    # Only the summary is embedded; a blank one has nothing to index.
    if not req.summary.strip():
        raise HTTPException(400, "Empty summary")

    vector = await _embed(req.summary)
    vec_literal = "[" + ",".join(str(component) for component in vector) + "]"
    # Fall back to "now" when the caller did not supply a timestamp.
    stored_ts = req.original_ts or time.time()

    async with pool.acquire() as conn:
        await conn.execute(
            """
            INSERT INTO conversation_chunks
            (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
            """,
            req.user_id,
            req.room_id,
            req.chunk_text,
            req.summary,
            req.source_event_id,
            stored_ts,
            vec_literal,
        )
    logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60])
    return {"stored": True}
|
||||
|
||||
|
||||
@app.post("/chunks/query")
async def query_chunks(req: ChunkQueryRequest):
    """Semantic search over conversation chunks. Filter by user_id and/or room_id."""
    query_vec = await _embed(req.query)
    vec_literal = "[" + ",".join(str(v) for v in query_vec) + "]"

    # Assemble optional equality filters. Placeholder numbering tracks the
    # params list: $1 is always the query embedding, so each filter value
    # appended at position len(params) is referenced as ${len(params)}.
    filters: list[str] = []
    params: list = [vec_literal]
    for column, value in (("user_id", req.user_id), ("room_id", req.room_id)):
        if value:
            params.append(value)
            filters.append(f"{column} = ${len(params)}")

    where = f"WHERE {' AND '.join(filters)}" if filters else ""
    params.append(req.top_k)
    limit_placeholder = len(params)

    async with pool.acquire() as conn:
        rows = await conn.fetch(
            f"""
            SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
                   1 - (embedding <=> $1::vector) AS similarity
            FROM conversation_chunks
            {where}
            ORDER BY embedding <=> $1::vector
            LIMIT ${limit_placeholder}
            """,
            *params,
        )

    return {
        "results": [
            {
                "chunk_text": row["chunk_text"],
                "summary": row["summary"],
                "room_id": row["room_id"],
                "user_id": row["user_id"],
                "original_ts": row["original_ts"],
                "source_event_id": row["source_event_id"],
                "similarity": float(row["similarity"]),
            }
            for row in rows
        ]
    }
|
||||
|
||||
|
||||
@app.post("/chunks/bulk-store")
async def bulk_store_chunks(req: ChunkBulkStoreRequest):
    """Batch store conversation chunks. Embeds summaries in batches of 20.

    Chunks with blank summaries are skipped up front — consistent with the
    single-store endpoint, which rejects them — rather than sent to the
    embedding API as empty text. Batches whose embedding call fails are
    logged and skipped. Returns the number of chunks actually stored.
    """
    usable = [c for c in req.chunks if c.summary.strip()]
    if not usable:
        return {"stored": 0}

    stored = 0
    batch_size = 20

    for i in range(0, len(usable), batch_size):
        batch = usable[i:i + batch_size]
        summaries = [c.summary.strip() for c in batch]

        try:
            embeddings = await _embed_batch(summaries)
        except Exception:
            # Best-effort bulk import: log the failed batch and keep going.
            logger.error("Batch embed failed for chunks %d-%d", i, i + len(batch), exc_info=True)
            continue

        async with pool.acquire() as conn:
            for chunk, embedding in zip(batch, embeddings):
                vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
                # Fall back to "now" when the caller did not supply a timestamp.
                ts = chunk.original_ts or time.time()
                await conn.execute(
                    """
                    INSERT INTO conversation_chunks
                    (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
                    VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
                    """,
                    chunk.user_id, chunk.room_id, chunk.chunk_text, chunk.summary,
                    chunk.source_event_id, ts, vec_literal,
                )
                stored += 1

    logger.info("Bulk stored %d chunks", stored)
    return {"stored": stored}
|
||||
|
||||
|
||||
@app.get("/chunks/{user_id}/count")
async def count_user_chunks(user_id: str):
    """Return how many conversation chunks are stored for *user_id*."""
    async with pool.acquire() as conn:
        total = await conn.fetchval(
            "SELECT count(*) FROM conversation_chunks WHERE user_id = $1", user_id,
        )
    return {"user_id": user_id, "count": total}
|
||||
|
||||
Reference in New Issue
Block a user