fix: harden Matrix ecosystem — pool recovery, parallel queries, voice persistence
- Memory service: asyncpg pool auto-reconnect on connection loss, IVFFlat lists 10→100
- Bot: parallel RAG/memory/chunk queries (asyncio.gather), parallel tool execution
- Bot: skip memory extraction for trivial messages (<20 chars, no personal facts)
- Bot: persist voice call transcripts as searchable conversation chunks
- RAG: JSON parse safety in AI metadata, embedding_status tracking, fetch timeouts
- Drive sync: token refresh mutex to prevent race conditions, fetch timeouts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -29,6 +29,7 @@ MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
|
||||
app = FastAPI(title="Memory Service")
|
||||
pool: asyncpg.Pool | None = None
|
||||
owner_pool: asyncpg.Pool | None = None
|
||||
_pool_healthy = True
|
||||
|
||||
|
||||
async def verify_token(authorization: str | None = Header(None)):
|
||||
@@ -161,9 +162,36 @@ async def _set_rls_user(conn, user_id: str):
|
||||
await conn.execute("SELECT set_config('app.current_user_id', $1, false)", user_id)
|
||||
|
||||
|
||||
async def _ensure_pool():
    """Recreate the connection pool if it was lost.

    Fast path: when the request pool exists and the last operation against it
    succeeded (module-level ``_pool_healthy`` flag), do nothing. Otherwise tear
    down both pools (best-effort) and build fresh ones from the configured DSNs.
    Raises whatever ``asyncpg.create_pool`` raises if reconnection fails, after
    marking the pools unhealthy so the next caller retries.
    """
    global pool, owner_pool, _pool_healthy

    # Nothing to do while the pool is present and believed healthy.
    if pool and _pool_healthy:
        return

    logger.warning(
        "Reconnecting asyncpg pools (healthy=%s, pool=%s)",
        _pool_healthy,
        pool is not None,
    )

    async def _close_quietly(stale):
        # Best-effort close of a possibly-broken pool; a failed close must not
        # prevent us from creating replacements.
        if stale is None:
            return
        try:
            await stale.close()
        except Exception:
            pass

    try:
        await _close_quietly(pool)
        await _close_quietly(owner_pool)
        # Request pool (RLS-restricted role) is larger; owner pool is kept
        # tiny since it is only used for admin/DDL-style queries.
        pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
        owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
        _pool_healthy = True
        logger.info("asyncpg pools reconnected successfully")
    except Exception:
        # Leave the flag down so the next request attempts reconnection again.
        _pool_healthy = False
        logger.exception("Failed to reconnect asyncpg pools")
        raise
|
||||
|
||||
|
||||
async def _init_db():
|
||||
"""Create pgvector extension and memories table if not exists."""
|
||||
global pool
|
||||
global pool, _pool_healthy
|
||||
# Use owner connection for DDL (CREATE TABLE/INDEX), then create restricted pool
|
||||
owner_conn = await asyncpg.connect(OWNER_DSN)
|
||||
conn = owner_conn
|
||||
@@ -185,7 +213,7 @@ async def _init_db():
|
||||
await conn.execute(f"""
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_embedding
|
||||
ON memories USING ivfflat (embedding vector_cosine_ops)
|
||||
WITH (lists = 10)
|
||||
WITH (lists = 100)
|
||||
""")
|
||||
# Conversation chunks table for RAG over chat history
|
||||
await conn.execute(f"""
|
||||
@@ -216,6 +244,7 @@ async def _init_db():
|
||||
# Owner pool for admin queries (bypasses RLS) — 1 connection only
|
||||
global owner_pool
|
||||
owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
|
||||
_pool_healthy = True
|
||||
logger.info("Database initialized (dims=%d, encryption=%s)", EMBED_DIMS, "ON" if ENCRYPTION_KEY else "OFF")
|
||||
|
||||
|
||||
@@ -234,7 +263,9 @@ async def shutdown():
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
if owner_pool:
|
||||
global _pool_healthy
|
||||
try:
|
||||
await _ensure_pool()
|
||||
async with owner_pool.acquire() as conn:
|
||||
mem_count = await conn.fetchval("SELECT count(*) FROM memories")
|
||||
chunk_count = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
|
||||
@@ -244,7 +275,10 @@ async def health():
|
||||
"total_chunks": chunk_count,
|
||||
"encryption": "on" if ENCRYPTION_KEY else "off",
|
||||
}
|
||||
return {"status": "no_db"}
|
||||
except Exception as e:
|
||||
_pool_healthy = False
|
||||
logger.error("Health check failed: %s", e)
|
||||
return {"status": "unhealthy", "error": str(e)}
|
||||
|
||||
|
||||
@app.post("/memories/store")
|
||||
@@ -256,6 +290,7 @@ async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
|
||||
embedding = await _embed(req.fact)
|
||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, req.user_id)
|
||||
|
||||
@@ -291,6 +326,7 @@ async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
|
||||
embedding = await _embed(req.query)
|
||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, req.user_id)
|
||||
|
||||
@@ -321,6 +357,7 @@ async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
|
||||
@app.delete("/memories/{user_id}")
|
||||
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||
"""GDPR delete — remove all memories for a user."""
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
result = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
|
||||
@@ -332,6 +369,7 @@ async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||
@app.get("/memories/{user_id}")
|
||||
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||
"""List all memories for a user (for UI/debug)."""
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
rows = await conn.fetch(
|
||||
@@ -369,6 +407,7 @@ async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
|
||||
encrypted_text = _encrypt(req.chunk_text, req.user_id)
|
||||
encrypted_summary = _encrypt(req.summary, req.user_id)
|
||||
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, req.user_id)
|
||||
await conn.execute(
|
||||
@@ -403,6 +442,7 @@ async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
|
||||
where = f"WHERE {' AND '.join(conditions)}"
|
||||
params.append(req.top_k)
|
||||
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, req.user_id)
|
||||
|
||||
@@ -452,6 +492,7 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify
|
||||
logger.error("Batch embed failed for chunks %d-%d", i, i + len(batch), exc_info=True)
|
||||
continue
|
||||
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
for chunk, embedding in zip(batch, embeddings):
|
||||
await _set_rls_user(conn, chunk.user_id)
|
||||
@@ -477,6 +518,7 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify
|
||||
@app.get("/chunks/{user_id}/count")
|
||||
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
|
||||
"""Count conversation chunks for a user."""
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
count = await conn.fetchval(
|
||||
|
||||
Reference in New Issue
Block a user