"""Memory Service — FastAPI microservice for per-user memories, conversation
chunks (RAG over chat history), and scheduled reminders.

Storage is Postgres + pgvector. Facts/chunk summaries are embedded via a
LiteLLM ``/embeddings`` endpoint. When ``MEMORY_ENCRYPTION_KEY`` is set,
stored text is encrypted with a per-user Fernet key. Request handlers use a
restricted connection pool and set ``app.current_user_id`` so Postgres RLS
policies (defined elsewhere) can apply; the owner pool bypasses RLS for DDL
and admin-style queries.
"""

import asyncio
import base64
import calendar
import datetime
import hashlib
import logging
import os
import secrets
import time

import asyncpg
import httpx
from cryptography.fernet import Fernet
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, field_validator

logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)

# Restricted DSN used by request handlers (RLS applies).
DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
# Owner DSN used for DDL and admin queries (bypasses RLS).
OWNER_DSN = os.environ.get(
    "OWNER_DATABASE_URL",
    "postgresql://memory:{pw}@memory-db:5432/memories".format(
        pw=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
    ),
)
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
# Cosine-similarity threshold above which a new fact is treated as a duplicate.
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")

app = FastAPI(title="Memory Service")
pool: asyncpg.Pool | None = None        # restricted pool (RLS applies)
owner_pool: asyncpg.Pool | None = None  # owner pool (bypasses RLS)
_pool_healthy = True
# Serializes pool reconnection so concurrent requests don't race to rebuild
# the pools and leak connections.
_pool_lock = asyncio.Lock()


async def verify_token(authorization: str | None = Header(None)):
    """Bearer token auth — skipped if MEMORY_SERVICE_TOKEN not configured (dev mode)."""
    if not MEMORY_SERVICE_TOKEN:
        return
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(401, "Missing or invalid Authorization header")
    # compare_digest: constant-time comparison to avoid timing side channels.
    if not secrets.compare_digest(authorization[7:], MEMORY_SERVICE_TOKEN):
        raise HTTPException(403, "Invalid token")


def _derive_user_key(user_id: str) -> bytes:
    """Derive a per-user Fernet key from the master key + user_id.

    Uses PBKDF2-HMAC-SHA256 with a single iteration: this is key *separation*
    per user, not password stretching (the master key is assumed to be
    high-entropy). NOTE: changing the hash, salt, or iteration count here
    would invalidate every existing ciphertext.

    Raises:
        RuntimeError: if MEMORY_ENCRYPTION_KEY is not configured.
    """
    if not ENCRYPTION_KEY:
        raise RuntimeError("MEMORY_ENCRYPTION_KEY not set")
    derived = hashlib.pbkdf2_hmac(
        "sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1
    )
    # Fernet requires a urlsafe-base64-encoded 32-byte key.
    return base64.urlsafe_b64encode(derived)


def _encrypt(text: str, user_id: str) -> str:
    """Encrypt text with per-user Fernet key. Returns base64 ciphertext.

    Pass-through (plaintext) when encryption is not configured.
    """
    if not ENCRYPTION_KEY:
        return text
    f = Fernet(_derive_user_key(user_id))
    return f.encrypt(text.encode()).decode()


def _decrypt(ciphertext: str, user_id: str) -> str:
    """Decrypt ciphertext with per-user Fernet key.

    Pass-through when encryption is off; falls back to returning the raw
    value for rows stored before encryption was enabled.
    """
    if not ENCRYPTION_KEY:
        return ciphertext
    try:
        f = Fernet(_derive_user_key(user_id))
        return f.decrypt(ciphertext.encode()).decode()
    except Exception:
        # Plaintext fallback for not-yet-migrated rows
        return ciphertext


class ScheduleRequest(BaseModel):
    """Payload for creating a scheduled message/reminder."""
    user_id: str
    room_id: str
    message_text: str
    scheduled_at: float  # Unix timestamp
    repeat_pattern: str = "once"  # once | daily | weekly | weekdays | monthly

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()

    @field_validator('repeat_pattern')
    @classmethod
    def valid_pattern(cls, v):
        allowed = {"once", "daily", "weekly", "weekdays", "monthly"}
        if v not in allowed:
            raise ValueError(f"repeat_pattern must be one of {allowed}")
        return v


class ScheduleCancelRequest(BaseModel):
    """Payload for cancelling a reminder (id + owner)."""
    id: int
    user_id: str


class StoreRequest(BaseModel):
    """Payload for storing a single memory fact."""
    user_id: str
    fact: str
    source_room: str = ""

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()


class QueryRequest(BaseModel):
    """Payload for semantic search over a user's memories."""
    user_id: str
    query: str
    top_k: int = 10

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()


class ChunkStoreRequest(BaseModel):
    """Payload for storing one conversation chunk (text + summary)."""
    user_id: str
    room_id: str
    chunk_text: str
    summary: str
    source_event_id: str = ""
    original_ts: float = 0.0

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()


class ChunkQueryRequest(BaseModel):
    """Payload for semantic search over conversation chunks."""
    user_id: str  # REQUIRED — no default
    room_id: str = ""
    query: str
    top_k: int = 5

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()


class ChunkBulkStoreRequest(BaseModel):
    """Payload for batch-storing conversation chunks."""
    chunks: list[ChunkStoreRequest]


async def _embed(text: str) -> list[float]:
    """Get embedding vector from LiteLLM /embeddings endpoint."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings",
            json={"model": EMBED_MODEL, "input": text},
            headers={"Authorization": f"Bearer {LITELLM_KEY}"},
        )
        resp.raise_for_status()
        return resp.json()["data"][0]["embedding"]


async def _embed_batch(texts: list[str]) -> list[list[float]]:
    """Get embedding vectors for a batch of texts.

    Results are re-sorted by the API-reported index so they line up with the
    input order regardless of response ordering.
    """
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings",
            json={"model": EMBED_MODEL, "input": texts},
            headers={"Authorization": f"Bearer {LITELLM_KEY}"},
        )
        resp.raise_for_status()
        data = resp.json()["data"]
        return [item["embedding"] for item in sorted(data, key=lambda x: x["index"])]


async def _set_rls_user(conn, user_id: str):
    """Set the RLS session variable for the current connection."""
    await conn.execute("SELECT set_config('app.current_user_id', $1, false)", user_id)


async def _ensure_pool():
    """Recreate the connection pools if they were lost.

    Lock-guarded with a double-check so that concurrent requests observing an
    unhealthy pool don't all tear it down and rebuild it simultaneously.
    """
    global pool, owner_pool, _pool_healthy
    if pool and _pool_healthy:
        return
    async with _pool_lock:
        # Re-check under the lock — another coroutine may have already
        # reconnected while we were waiting.
        if pool and _pool_healthy:
            return
        logger.warning("Reconnecting asyncpg pools (healthy=%s, pool=%s)",
                       _pool_healthy, pool is not None)
        try:
            if pool:
                try:
                    await pool.close()
                except Exception:
                    pass
            if owner_pool:
                try:
                    await owner_pool.close()
                except Exception:
                    pass
            pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
            owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
            _pool_healthy = True
            logger.info("asyncpg pools reconnected successfully")
        except Exception:
            _pool_healthy = False
            logger.exception("Failed to reconnect asyncpg pools")
            raise


async def _init_db():
    """Create pgvector extension and tables/indexes, then open both pools.

    DDL runs on a dedicated owner connection; request handlers get the
    restricted pool so RLS policies apply.
    """
    global pool, owner_pool, _pool_healthy
    # Use owner connection for DDL (CREATE TABLE/INDEX), then create restricted pool
    owner_conn = await asyncpg.connect(OWNER_DSN)
    conn = owner_conn
    try:
        await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        await conn.execute(f"""
            CREATE TABLE IF NOT EXISTS memories (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                fact TEXT NOT NULL,
                source_room TEXT DEFAULT '',
                created_at DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id)
        """)
        # ivfflat ANN index for cosine-distance nearest-neighbor queries.
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_embedding ON memories
            USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100)
        """)
        # Conversation chunks table for RAG over chat history
        await conn.execute(f"""
            CREATE TABLE IF NOT EXISTS conversation_chunks (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                room_id TEXT NOT NULL,
                chunk_text TEXT NOT NULL,
                summary TEXT NOT NULL,
                source_event_id TEXT DEFAULT '',
                original_ts DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user_id ON conversation_chunks (user_id)
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
        """)
        # Scheduled messages table for reminders
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS scheduled_messages (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                room_id TEXT NOT NULL,
                message_text TEXT NOT NULL,
                scheduled_at DOUBLE PRECISION NOT NULL,
                created_at DOUBLE PRECISION NOT NULL,
                status TEXT DEFAULT 'pending',
                repeat_pattern TEXT DEFAULT 'once',
                repeat_interval_seconds INTEGER DEFAULT 0,
                last_sent_at DOUBLE PRECISION DEFAULT 0
            )
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_scheduled_user_id ON scheduled_messages (user_id)
        """)
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_scheduled_status ON scheduled_messages (status, scheduled_at)
        """)
    finally:
        await owner_conn.close()
    # Create restricted pool for all request handlers (RLS applies)
    pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
    # Owner pool for admin queries (bypasses RLS) — 1 connection only
    owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
    _pool_healthy = True
    logger.info("Database initialized (dims=%d, encryption=%s)",
                EMBED_DIMS, "ON" if ENCRYPTION_KEY else "OFF")


@app.on_event("startup")
async def startup():
    await _init_db()


@app.on_event("shutdown")
async def shutdown():
    if pool:
        await pool.close()
    if owner_pool:
        await owner_pool.close()


@app.get("/health")
async def health():
    """Liveness probe: verifies DB connectivity and reports row counts."""
    global _pool_healthy
    try:
        await _ensure_pool()
        async with owner_pool.acquire() as conn:
            mem_count = await conn.fetchval("SELECT count(*) FROM memories")
            chunk_count = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
            sched_count = await conn.fetchval(
                "SELECT count(*) FROM scheduled_messages WHERE status = 'pending'")
        return {
            "status": "ok",
            "total_memories": mem_count,
            "total_chunks": chunk_count,
            "pending_reminders": sched_count,
            "encryption": "on" if ENCRYPTION_KEY else "off",
        }
    except Exception as e:
        # Report unhealthy (HTTP 200 with status field) and flag the pool so
        # the next request attempts a reconnect.
        _pool_healthy = False
        logger.error("Health check failed: %s", e)
        return {"status": "unhealthy", "error": str(e)}


@app.post("/memories/store")
async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
    """Embed fact, deduplicate by cosine similarity, insert encrypted."""
    if not req.fact.strip():
        raise HTTPException(400, "Empty fact")
    embedding = await _embed(req.fact)
    # pgvector accepts a '[v1,v2,...]' text literal cast with ::vector.
    vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        # Check for duplicates (cosine similarity > threshold)
        dup = await conn.fetchval(
            """
            SELECT id FROM memories
            WHERE user_id = $1 AND 1 - (embedding <=> $2::vector) > $3
            LIMIT 1
            """,
            req.user_id, vec_literal, DEDUP_THRESHOLD,
        )
        if dup:
            logger.info("Duplicate memory for %s (similar to id=%d), skipping",
                        req.user_id, dup)
            return {"stored": False, "reason": "duplicate"}
        encrypted_fact = _encrypt(req.fact.strip(), req.user_id)
        await conn.execute(
            """
            INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
            VALUES ($1, $2, $3, $4, $5::vector)
            """,
            req.user_id, encrypted_fact, req.source_room, time.time(), vec_literal,
        )
    logger.info("Stored memory for %s: %s", req.user_id, req.fact[:60])
    return {"stored": True}


@app.post("/memories/query")
async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
    """Embed query, return top-K similar facts for user."""
    embedding = await _embed(req.query)
    vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at,
                   1 - (embedding <=> $1::vector) AS similarity
            FROM memories
            WHERE user_id = $2
            ORDER BY embedding <=> $1::vector
            LIMIT $3
            """,
            vec_literal, req.user_id, req.top_k,
        )
    results = [
        {
            "fact": _decrypt(r["fact"], req.user_id),
            "source_room": r["source_room"],
            "created_at": r["created_at"],
            "similarity": float(r["similarity"]),
        }
        for r in rows
    ]
    return {"results": results}


@app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
    """GDPR delete — remove all memories for a user.

    NOTE(review): conversation_chunks and scheduled_messages for the user are
    NOT deleted here — confirm whether a full GDPR erasure should cover them.
    """
    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        result = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
    # asyncpg returns a status tag like 'DELETE 3'; last token is the count.
    count = int(result.split()[-1])
    logger.info("Deleted %d memories for %s", count, user_id)
    return {"deleted": count}


@app.get("/memories/{user_id}")
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
    """List all memories for a user (for UI/debug)."""
    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at FROM memories
            WHERE user_id = $1 ORDER BY created_at DESC
            """,
            user_id,
        )
    return {
        "user_id": user_id,
        "count": len(rows),
        "memories": [
            {"fact": _decrypt(r["fact"], user_id),
             "source_room": r["source_room"],
             "created_at": r["created_at"]}
            for r in rows
        ],
    }


# --- Conversation Chunks ---

@app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
    """Store a conversation chunk with its summary embedding, encrypted."""
    if not req.summary.strip():
        raise HTTPException(400, "Empty summary")
    # The summary (not the full chunk text) is what gets embedded/searched.
    embedding = await _embed(req.summary)
    vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
    ts = req.original_ts or time.time()
    encrypted_text = _encrypt(req.chunk_text, req.user_id)
    encrypted_summary = _encrypt(req.summary, req.user_id)
    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        await conn.execute(
            """
            INSERT INTO conversation_chunks
                (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
            """,
            req.user_id, req.room_id, encrypted_text, encrypted_summary,
            req.source_event_id, ts, vec_literal,
        )
    logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60])
    return {"stored": True}


@app.post("/chunks/query")
async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
    """Semantic search over conversation chunks. Filter by user_id and/or room_id."""
    embedding = await _embed(req.query)
    vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
    # Build WHERE clause — user_id is always required. Only static fragments
    # are interpolated; all values go through positional parameters.
    conditions = ["user_id = $2"]
    params: list = [vec_literal, req.user_id]
    idx = 3
    if req.room_id:
        conditions.append(f"room_id = ${idx}")
        params.append(req.room_id)
        idx += 1
    where = f"WHERE {' AND '.join(conditions)}"
    params.append(req.top_k)
    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            f"""
            SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
                   1 - (embedding <=> $1::vector) AS similarity
            FROM conversation_chunks
            {where}
            ORDER BY embedding <=> $1::vector
            LIMIT ${idx}
            """,
            *params,
        )
    results = [
        {
            "chunk_text": _decrypt(r["chunk_text"], r["user_id"]),
            "summary": _decrypt(r["summary"], r["user_id"]),
            "room_id": r["room_id"],
            "user_id": r["user_id"],
            "original_ts": r["original_ts"],
            "source_event_id": r["source_event_id"],
            "similarity": float(r["similarity"]),
        }
        for r in rows
    ]
    return {"results": results}


@app.post("/chunks/bulk-store")
async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
    """Batch store conversation chunks. Embeds summaries in batches of 20.

    Best-effort: a failed embedding batch is logged and skipped rather than
    aborting the whole request.
    """
    if not req.chunks:
        return {"stored": 0}
    stored = 0
    batch_size = 20
    for i in range(0, len(req.chunks), batch_size):
        batch = req.chunks[i:i + batch_size]
        summaries = [c.summary.strip() for c in batch]
        try:
            embeddings = await _embed_batch(summaries)
        except Exception:
            logger.error("Batch embed failed for chunks %d-%d", i, i + len(batch),
                         exc_info=True)
            continue
        await _ensure_pool()
        async with pool.acquire() as conn:
            for chunk, embedding in zip(batch, embeddings):
                # Chunks in one batch may belong to different users; reset the
                # RLS variable per row.
                await _set_rls_user(conn, chunk.user_id)
                vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
                ts = chunk.original_ts or time.time()
                encrypted_text = _encrypt(chunk.chunk_text, chunk.user_id)
                encrypted_summary = _encrypt(chunk.summary, chunk.user_id)
                await conn.execute(
                    """
                    INSERT INTO conversation_chunks
                        (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
                    VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
                    """,
                    chunk.user_id, chunk.room_id, encrypted_text, encrypted_summary,
                    chunk.source_event_id, ts, vec_literal,
                )
                stored += 1
    logger.info("Bulk stored %d chunks", stored)
    return {"stored": stored}


@app.get("/chunks/{user_id}/count")
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
    """Count conversation chunks for a user."""
    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        count = await conn.fetchval(
            "SELECT count(*) FROM conversation_chunks WHERE user_id = $1",
            user_id,
        )
    return {"user_id": user_id, "count": count}


# --- Scheduled Messages ---

def _compute_repeat_interval(pattern: str) -> int:
    """Compute repeat_interval_seconds from pattern name.

    'weekdays' and 'monthly' need date math at send time (see
    _next_scheduled_at), so a fixed interval is not meaningful for them.
    """
    return {
        "once": 0,
        "daily": 86400,
        "weekly": 604800,
        "weekdays": 86400,  # special handling in mark-sent
        "monthly": 0,       # special handling in mark-sent
    }.get(pattern, 0)


def _next_scheduled_at(current_ts: float, pattern: str) -> float:
    """Compute the next scheduled_at timestamp for recurring patterns.

    Returns current_ts unchanged for 'once' or unknown patterns. All date
    math is done in UTC.
    """
    dt = datetime.datetime.fromtimestamp(current_ts, tz=datetime.timezone.utc)
    if pattern == "daily":
        return current_ts + 86400.0
    elif pattern == "weekly":
        return current_ts + 604800.0
    elif pattern == "weekdays":
        next_dt = dt + datetime.timedelta(days=1)
        while next_dt.weekday() >= 5:  # Skip Sat(5), Sun(6)
            next_dt += datetime.timedelta(days=1)
        return next_dt.timestamp()
    elif pattern == "monthly":
        month = dt.month + 1
        year = dt.year + (month - 1) // 12
        month = (month - 1) % 12 + 1
        # Clamp the day so e.g. Jan 31 -> Feb 28/29.
        day = min(dt.day, calendar.monthrange(year, month)[1])
        return dt.replace(year=year, month=month, day=day).timestamp()
    return current_ts


MAX_REMINDERS_PER_USER = 50


@app.post("/scheduled/create")
async def create_scheduled(req: ScheduleRequest, _: None = Depends(verify_token)):
    """Create a new scheduled message/reminder."""
    now = time.time()
    if req.scheduled_at <= now:
        raise HTTPException(400, "scheduled_at must be in the future")
    # Check max reminders per user
    await _ensure_pool()
    async with owner_pool.acquire() as conn:
        count = await conn.fetchval(
            "SELECT count(*) FROM scheduled_messages WHERE user_id = $1 AND status = 'pending'",
            req.user_id,
        )
        if count >= MAX_REMINDERS_PER_USER:
            raise HTTPException(400, f"Maximum {MAX_REMINDERS_PER_USER} active reminders per user")
        msg_text = req.message_text[:2000]  # Truncate long messages
        interval = _compute_repeat_interval(req.repeat_pattern)
        row_id = await conn.fetchval(
            """
            INSERT INTO scheduled_messages
                (user_id, room_id, message_text, scheduled_at, created_at,
                 status, repeat_pattern, repeat_interval_seconds)
            VALUES ($1, $2, $3, $4, $5, 'pending', $6, $7)
            RETURNING id
            """,
            req.user_id, req.room_id, msg_text, req.scheduled_at, now,
            req.repeat_pattern, interval,
        )
    logger.info("Created reminder #%d for %s at %.0f (%s)",
                row_id, req.user_id, req.scheduled_at, req.repeat_pattern)
    return {"id": row_id, "created": True}


@app.get("/scheduled/{user_id}")
async def list_scheduled(user_id: str, _: None = Depends(verify_token)):
    """List all pending/active reminders for a user."""
    await _ensure_pool()
    async with owner_pool.acquire() as conn:
        rows = await conn.fetch(
            """
            SELECT id, message_text, scheduled_at, repeat_pattern, status
            FROM scheduled_messages
            WHERE user_id = $1 AND status = 'pending'
            ORDER BY scheduled_at
            """,
            user_id,
        )
    return {
        "user_id": user_id,
        "reminders": [
            {
                "id": r["id"],
                "message_text": r["message_text"],
                "scheduled_at": r["scheduled_at"],
                "repeat_pattern": r["repeat_pattern"],
                "status": r["status"],
            }
            for r in rows
        ],
    }


@app.delete("/scheduled/{user_id}/{reminder_id}")
async def cancel_scheduled(user_id: str, reminder_id: int, _: None = Depends(verify_token)):
    """Cancel a reminder. Only the owner can cancel."""
    await _ensure_pool()
    async with owner_pool.acquire() as conn:
        result = await conn.execute(
            """
            UPDATE scheduled_messages SET status = 'cancelled'
            WHERE id = $1 AND user_id = $2 AND status = 'pending'
            """,
            reminder_id, user_id,
        )
    count = int(result.split()[-1])
    if count == 0:
        raise HTTPException(404, "Reminder not found or already cancelled")
    logger.info("Cancelled reminder #%d for %s", reminder_id, user_id)
    return {"cancelled": True, "id": reminder_id}


@app.post("/scheduled/due")
async def get_due_messages(_: None = Depends(verify_token)):
    """Return all messages that are due (scheduled_at <= now, status = pending)."""
    now = time.time()
    await _ensure_pool()
    async with owner_pool.acquire() as conn:
        rows = await conn.fetch(
            """
            SELECT id, user_id, room_id, message_text, scheduled_at, repeat_pattern
            FROM scheduled_messages
            WHERE scheduled_at <= $1 AND status = 'pending'
            ORDER BY scheduled_at
            LIMIT 100
            """,
            now,
        )
    return {
        "due": [
            {
                "id": r["id"],
                "user_id": r["user_id"],
                "room_id": r["room_id"],
                "message_text": r["message_text"],
                "scheduled_at": r["scheduled_at"],
                "repeat_pattern": r["repeat_pattern"],
            }
            for r in rows
        ],
    }


@app.post("/scheduled/{reminder_id}/mark-sent")
async def mark_sent(reminder_id: int, _: None = Depends(verify_token)):
    """Mark a reminder as sent. For recurring, compute next scheduled_at."""
    now = time.time()
    await _ensure_pool()
    async with owner_pool.acquire() as conn:
        row = await conn.fetchrow(
            "SELECT repeat_pattern, scheduled_at FROM scheduled_messages WHERE id = $1",
            reminder_id,
        )
        if not row:
            raise HTTPException(404, "Reminder not found")
        if row["repeat_pattern"] == "once":
            await conn.execute(
                "UPDATE scheduled_messages SET status = 'sent', last_sent_at = $1 WHERE id = $2",
                now, reminder_id,
            )
        else:
            next_at = _next_scheduled_at(row["scheduled_at"], row["repeat_pattern"])
            # FIX: if the reminder was long overdue (e.g. the service was down),
            # a single step from the stale scheduled_at can still land in the
            # past, making the reminder immediately due again in a tight loop.
            # Advance until strictly in the future; bail out defensively if the
            # pattern makes no progress (unknown pattern returns input as-is).
            while next_at <= now:
                stepped = _next_scheduled_at(next_at, row["repeat_pattern"])
                if stepped <= next_at:
                    break
                next_at = stepped
            await conn.execute(
                """
                UPDATE scheduled_messages SET scheduled_at = $1, last_sent_at = $2
                WHERE id = $3
                """,
                next_at, now, reminder_id,
            )
    logger.info("Marked reminder #%d as sent (pattern=%s)", reminder_id, row["repeat_pattern"])
    return {"marked": True}