diff --git a/docker-compose.yml b/docker-compose.yml
index f7fc2b8..a276a66 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -37,6 +37,13 @@ services:
       POSTGRES_DB: memories
     volumes:
       - memory-pgdata:/var/lib/postgresql/data
+      - ./memory-db-ssl/server.crt:/var/lib/postgresql/server.crt:ro
+      - ./memory-db-ssl/server.key:/var/lib/postgresql/server.key:ro
+    command: >
+      postgres
+      -c ssl=on
+      -c ssl_cert_file=/var/lib/postgresql/server.crt
+      -c ssl_key_file=/var/lib/postgresql/server.key
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U memory -d memories"]
      interval: 5s
@@ -47,7 +54,9 @@ services:
     build: ./memory-service
     restart: unless-stopped
     environment:
-      DATABASE_URL: postgresql://memory:${MEMORY_DB_PASSWORD:-memory}@memory-db:5432/memories
+      DATABASE_URL: postgresql://memory_app:${MEMORY_APP_PASSWORD}@memory-db:5432/memories?sslmode=require
+      MEMORY_ENCRYPTION_KEY: ${MEMORY_ENCRYPTION_KEY}
+      MEMORY_DB_OWNER_PASSWORD: ${MEMORY_DB_PASSWORD:-memory}
       LITELLM_BASE_URL: ${LITELLM_BASE_URL}
       LITELLM_API_KEY: ${LITELLM_MASTER_KEY}
       EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
diff --git a/memory-service/Dockerfile b/memory-service/Dockerfile
index 6df5823..3434539 100644
--- a/memory-service/Dockerfile
+++ b/memory-service/Dockerfile
@@ -2,5 +2,5 @@ FROM python:3.11-slim
 WORKDIR /app
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
-COPY main.py .
+COPY main.py migrate_encrypt.py ./
 CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8090"]
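The compose change introduces two new secrets: `MEMORY_APP_PASSWORD` (which must match the `memory_app` role password created in `migrate_rls.sql` below) and `MEMORY_ENCRYPTION_KEY`. Because the key derivation in `main.py` runs PBKDF2 with a single iteration, the master key itself must be high-entropy. A minimal sketch for generating both values, assuming a `.env` file is how the deployment feeds compose:

```python
# Sketch: generate values for the two new .env entries. Illustrative only;
# MEMORY_APP_PASSWORD must match the memory_app password in migrate_rls.sql.
import secrets

print(f"MEMORY_APP_PASSWORD={secrets.token_urlsafe(24)}")
print(f"MEMORY_ENCRYPTION_KEY={secrets.token_urlsafe(32)}")
```

Note that asyncpg understands the `?sslmode=require` suffix in the DSN, so the service picks up TLS without any code change beyond the new URL.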
diff --git a/memory-service/main.py b/memory-service/main.py
index f550488..8dde4ec 100644
--- a/memory-service/main.py
+++ b/memory-service/main.py
@@ -1,9 +1,12 @@
 import os
 import logging
 import time
+import hashlib
+import base64
 
 import asyncpg
 import httpx
+from cryptography.fernet import Fernet
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 
@@ -11,14 +14,49 @@
 logger = logging.getLogger("memory-service")
 logging.basicConfig(level=logging.INFO)
 
 DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
+OWNER_DSN = os.environ.get("OWNER_DATABASE_URL", "postgresql://memory:{pw}@memory-db:5432/memories".format(
+    pw=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
+))
 LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
 LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
 EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
 EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
 DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
+ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
 
 app = FastAPI(title="Memory Service")
 pool: asyncpg.Pool | None = None
+owner_pool: asyncpg.Pool | None = None
+
+
+def _derive_user_key(user_id: str) -> bytes:
+    """Derive a per-user Fernet key via PBKDF2-HMAC-SHA256 (one iteration; the master key is already high-entropy)."""
+    if not ENCRYPTION_KEY:
+        raise RuntimeError("MEMORY_ENCRYPTION_KEY not set")
+    derived = hashlib.pbkdf2_hmac(
+        "sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1
+    )
+    return base64.urlsafe_b64encode(derived)
+
+
+def _encrypt(text: str, user_id: str) -> str:
+    """Encrypt text with the per-user Fernet key. Returns a base64 token."""
+    if not ENCRYPTION_KEY:
+        return text
+    f = Fernet(_derive_user_key(user_id))
+    return f.encrypt(text.encode()).decode()
+
+
+def _decrypt(ciphertext: str, user_id: str) -> str:
+    """Decrypt ciphertext with the per-user Fernet key."""
+    if not ENCRYPTION_KEY:
+        return ciphertext
+    try:
+        f = Fernet(_derive_user_key(user_id))
+        return f.decrypt(ciphertext.encode()).decode()
+    except Exception:
+        # Plaintext fallback for not-yet-migrated rows
+        return ciphertext
 
 
 class StoreRequest(BaseModel):
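The derivation is worth a quick sanity check: the same master key plus user id must yield a stable Fernet key, and one user's key must not open another user's tokens. A minimal sketch, using a throwaway master key and made-up Matrix IDs:

```python
# Sketch: mirror _derive_user_key/_encrypt with a throwaway master key
# and confirm cross-user decryption fails. All values here are made up.
import base64
import hashlib

from cryptography.fernet import Fernet, InvalidToken

master = "test-master-key"

def key_for(user_id: str) -> bytes:
    # 32-byte PBKDF2 output, urlsafe-base64-encoded: a valid Fernet key
    derived = hashlib.pbkdf2_hmac("sha256", master.encode(), user_id.encode(), iterations=1)
    return base64.urlsafe_b64encode(derived)

token = Fernet(key_for("@alice:example.org")).encrypt(b"likes postgres")
assert Fernet(key_for("@alice:example.org")).decrypt(token) == b"likes postgres"
try:
    Fernet(key_for("@bob:example.org")).decrypt(token)  # wrong user's key
except InvalidToken:
    print("cross-user decryption fails, as intended")
```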
@@ -78,11 +116,18 @@ async def _embed_batch(texts: list[str]) -> list[list[float]]:
     return [item["embedding"] for item in sorted(data, key=lambda x: x["index"])]
 
 
+async def _set_rls_user(conn, user_id: str):
+    """Set the RLS session variable for the current connection."""
+    await conn.execute("SELECT set_config('app.current_user_id', $1, false)", user_id)
+
+
 async def _init_db():
     """Create pgvector extension and memories table if not exists."""
     global pool
-    pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
-    async with pool.acquire() as conn:
+    # Use owner connection for DDL (CREATE TABLE/INDEX), then create restricted pool
+    owner_conn = await asyncpg.connect(OWNER_DSN)
+    conn = owner_conn
+    try:
         await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
         await conn.execute(f"""
             CREATE TABLE IF NOT EXISTS memories (
@@ -121,7 +166,14 @@ async def _init_db():
         await conn.execute("""
             CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
         """)
-    logger.info("Database initialized (dims=%d)", EMBED_DIMS)
+    finally:
+        await owner_conn.close()
+    # Create restricted pool for all request handlers (RLS applies)
+    pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
+    # Owner pool for admin queries (bypasses RLS) — 1 connection only
+    global owner_pool
+    owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
+    logger.info("Database initialized (dims=%d, encryption=%s)", EMBED_DIMS, "ON" if ENCRYPTION_KEY else "OFF")
 
 
 @app.on_event("startup")
@@ -133,21 +185,28 @@ async def startup():
 async def shutdown():
     if pool:
         await pool.close()
+    if owner_pool:
+        await owner_pool.close()
 
 
 @app.get("/health")
 async def health():
-    if pool:
-        async with pool.acquire() as conn:
+    if owner_pool:
+        async with owner_pool.acquire() as conn:
             mem_count = await conn.fetchval("SELECT count(*) FROM memories")
             chunk_count = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
-            return {"status": "ok", "total_memories": mem_count, "total_chunks": chunk_count}
+            return {
+                "status": "ok",
+                "total_memories": mem_count,
+                "total_chunks": chunk_count,
+                "encryption": "on" if ENCRYPTION_KEY else "off",
+            }
     return {"status": "no_db"}
 
 
 @app.post("/memories/store")
 async def store_memory(req: StoreRequest):
-    """Embed fact, deduplicate by cosine similarity, insert."""
+    """Embed fact, deduplicate by cosine similarity, insert encrypted."""
     if not req.fact.strip():
         raise HTTPException(400, "Empty fact")
 
@@ -155,6 +214,8 @@ async def store_memory(req: StoreRequest):
     vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
 
     async with pool.acquire() as conn:
+        await _set_rls_user(conn, req.user_id)
+
         # Check for duplicates (cosine similarity > threshold)
         dup = await conn.fetchval(
             """
@@ -169,12 +230,13 @@
             logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, dup)
             return {"stored": False, "reason": "duplicate"}
 
+        encrypted_fact = _encrypt(req.fact.strip(), req.user_id)
         await conn.execute(
             """
             INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
             VALUES ($1, $2, $3, $4, $5::vector)
             """,
-            req.user_id, req.fact.strip(), req.source_room, time.time(), vec_literal,
+            req.user_id, encrypted_fact, req.source_room, time.time(), vec_literal,
         )
         logger.info("Stored memory for %s: %s", req.user_id, req.fact[:60])
         return {"stored": True}
@@ -187,6 +249,8 @@ async def query_memories(req: QueryRequest):
     vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
 
     async with pool.acquire() as conn:
+        await _set_rls_user(conn, req.user_id)
+
         rows = await conn.fetch(
             """
             SELECT fact, source_room, created_at,
@@ -201,7 +265,7 @@
 
     results = [
         {
-            "fact": r["fact"],
+            "fact": _decrypt(r["fact"], req.user_id),
             "source_room": r["source_room"],
             "created_at": r["created_at"],
             "similarity": float(r["similarity"]),
@@ -215,6 +279,7 @@
 async def delete_user_memories(user_id: str):
     """GDPR delete — remove all memories for a user."""
     async with pool.acquire() as conn:
+        await _set_rls_user(conn, user_id)
         result = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
         count = int(result.split()[-1])
         logger.info("Deleted %d memories for %s", count, user_id)
@@ -225,6 +290,7 @@
 async def list_user_memories(user_id: str):
     """List all memories for a user (for UI/debug)."""
     async with pool.acquire() as conn:
+        await _set_rls_user(conn, user_id)
         rows = await conn.fetch(
             """
             SELECT fact, source_room, created_at
@@ -238,7 +304,7 @@
         "user_id": user_id,
         "count": len(rows),
         "memories": [
-            {"fact": r["fact"], "source_room": r["source_room"], "created_at": r["created_at"]}
+            {"fact": _decrypt(r["fact"], user_id), "source_room": r["source_room"], "created_at": r["created_at"]}
            for r in rows
        ],
    }
@@ -249,7 +315,7 @@
 
 
 @app.post("/chunks/store")
 async def store_chunk(req: ChunkStoreRequest):
-    """Store a conversation chunk with its summary embedding."""
+    """Store a conversation chunk with its summary embedding, encrypted."""
     if not req.summary.strip():
         raise HTTPException(400, "Empty summary")
 
@@ -257,14 +323,18 @@
     vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
     ts = req.original_ts or time.time()
 
+    encrypted_text = _encrypt(req.chunk_text, req.user_id)
+    encrypted_summary = _encrypt(req.summary, req.user_id)
+
     async with pool.acquire() as conn:
+        await _set_rls_user(conn, req.user_id)
         await conn.execute(
             """
             INSERT INTO conversation_chunks (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
             VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
             """,
-            req.user_id, req.room_id, req.chunk_text, req.summary,
+            req.user_id, req.room_id, encrypted_text, encrypted_summary,
             req.source_event_id, ts, vec_literal,
         )
         logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60])
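With RLS pinned per request and decryption applied on the way out, the HTTP surface is unchanged. A smoke-test sketch with httpx; the base URL and the `/memories/{user_id}` paths for list/delete are assumptions (those two decorators fall outside the hunks above), and `source_room` is an illustrative value:

```python
# Sketch: store, list, and delete a memory end to end. Adjust BASE to your
# compose port mapping; routes for list/delete assumed from the handler names.
import httpx

BASE = "http://localhost:8090"
USER = "@smoketest:example.org"

with httpx.Client(base_url=BASE, timeout=30) as client:
    r = client.post("/memories/store", json={
        "user_id": USER,
        "fact": "prefers dark mode",
        "source_room": "!demo:example.org",  # illustrative
    })
    r.raise_for_status()
    print(r.json())  # {"stored": true}; a second identical run reports "duplicate"

    listed = client.get(f"/memories/{USER}").json()
    # Facts come back decrypted; ciphertext never leaves the database layer
    print(listed["count"], [m["fact"] for m in listed["memories"]])

    print(client.delete(f"/memories/{USER}").json())  # GDPR-style cleanup
```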
@@ -294,7 +364,13 @@
     where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
     params.append(req.top_k)
 
+    # Determine user_id for RLS and decryption; reset the session variable even
+    # when unset, so a pooled connection never reuses a previous request's value
+    rls_user = req.user_id if req.user_id else ""
+
     async with pool.acquire() as conn:
+        await _set_rls_user(conn, rls_user)
+
         rows = await conn.fetch(
             f"""
             SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
@@ -309,8 +385,8 @@ async def query_chunks(req: ChunkQueryRequest):
 
     results = [
         {
-            "chunk_text": r["chunk_text"],
-            "summary": r["summary"],
+            "chunk_text": _decrypt(r["chunk_text"], r["user_id"]),
+            "summary": _decrypt(r["summary"], r["user_id"]),
             "room_id": r["room_id"],
             "user_id": r["user_id"],
             "original_ts": r["original_ts"],
@@ -343,15 +419,18 @@
     async with pool.acquire() as conn:
         for chunk, embedding in zip(batch, embeddings):
+            await _set_rls_user(conn, chunk.user_id)
             vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
             ts = chunk.original_ts or time.time()
+            encrypted_text = _encrypt(chunk.chunk_text, chunk.user_id)
+            encrypted_summary = _encrypt(chunk.summary, chunk.user_id)
             await conn.execute(
                 """
                 INSERT INTO conversation_chunks (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
                 VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
                 """,
-                chunk.user_id, chunk.room_id, chunk.chunk_text, chunk.summary,
+                chunk.user_id, chunk.room_id, encrypted_text, encrypted_summary,
                 chunk.source_event_id, ts, vec_literal,
             )
             stored += 1
@@ -364,6 +443,7 @@
 async def count_user_chunks(user_id: str):
     """Count conversation chunks for a user."""
     async with pool.acquire() as conn:
+        await _set_rls_user(conn, user_id)
         count = await conn.fetchval(
             "SELECT count(*) FROM conversation_chunks WHERE user_id = $1", user_id,
         )
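The acquire/set_config/query pattern every handler now follows is easiest to see in isolation. A sketch connecting directly as the restricted role; the password placeholder and host are assumptions, and the role itself comes from `migrate_rls.sql` below:

```python
# Sketch: run the per-request RLS pattern by hand. CHANGE_ME stands in for
# the memory_app password created in migrate_rls.sql; host/port are local.
import asyncio

import asyncpg

async def main() -> None:
    conn = await asyncpg.connect(
        "postgresql://memory_app:CHANGE_ME@localhost:5432/memories?sslmode=require"
    )
    # Rows are visible only while the session variable matches their user_id
    await conn.execute("SELECT set_config('app.current_user_id', $1, false)", "@alice:example.org")
    print(await conn.fetchval("SELECT count(*) FROM memories"))  # alice's rows only
    # An empty identity matches nothing: the safe default between requests
    await conn.execute("SELECT set_config('app.current_user_id', $1, false)", "")
    print(await conn.fetchval("SELECT count(*) FROM memories"))  # 0
    await conn.close()

asyncio.run(main())
```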
+""" +import os +import sys +import hashlib +import base64 +import asyncio + +import asyncpg +from cryptography.fernet import Fernet, InvalidToken + +OWNER_DSN = os.environ.get( + "OWNER_DATABASE_URL", + "postgresql://memory:{password}@memory-db:5432/memories".format( + password=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory") + ), +) +ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "") + + +def _derive_user_key(user_id: str) -> bytes: + derived = hashlib.pbkdf2_hmac("sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1) + return base64.urlsafe_b64encode(derived) + + +def _encrypt(text: str, user_id: str) -> str: + f = Fernet(_derive_user_key(user_id)) + return f.encrypt(text.encode()).decode() + + +def _is_encrypted(text: str, user_id: str) -> bool: + """Check if text is already Fernet-encrypted.""" + try: + f = Fernet(_derive_user_key(user_id)) + f.decrypt(text.encode()) + return True + except (InvalidToken, Exception): + return False + + +async def migrate(): + if not ENCRYPTION_KEY: + print("ERROR: MEMORY_ENCRYPTION_KEY not set") + sys.exit(1) + + conn = await asyncpg.connect(OWNER_DSN) + + # Migrate memories + rows = await conn.fetch("SELECT id, user_id, fact FROM memories ORDER BY id") + print(f"Migrating {len(rows)} memories...") + encrypted = 0 + skipped = 0 + for row in rows: + if _is_encrypted(row["fact"], row["user_id"]): + skipped += 1 + continue + enc_fact = _encrypt(row["fact"], row["user_id"]) + await conn.execute("UPDATE memories SET fact = $1 WHERE id = $2", enc_fact, row["id"]) + encrypted += 1 + if encrypted % 100 == 0: + print(f" memories: {encrypted}/{len(rows)} encrypted") + print(f"Memories done: {encrypted} encrypted, {skipped} already encrypted") + + # Migrate conversation_chunks + rows = await conn.fetch("SELECT id, user_id, chunk_text, summary FROM conversation_chunks ORDER BY id") + print(f"Migrating {len(rows)} chunks...") + encrypted = 0 + skipped = 0 + for row in rows: + if _is_encrypted(row["chunk_text"], row["user_id"]): + skipped += 1 + continue + enc_text = _encrypt(row["chunk_text"], row["user_id"]) + enc_summary = _encrypt(row["summary"], row["user_id"]) + await conn.execute( + "UPDATE conversation_chunks SET chunk_text = $1, summary = $2 WHERE id = $3", + enc_text, enc_summary, row["id"], + ) + encrypted += 1 + if encrypted % 500 == 0: + print(f" chunks: {encrypted}/{len(rows)} encrypted") + print(f"Chunks done: {encrypted} encrypted, {skipped} already encrypted") + + await conn.close() + print("Migration complete.") + + +if __name__ == "__main__": + asyncio.run(migrate()) diff --git a/memory-service/requirements.txt b/memory-service/requirements.txt index 5f17c58..9a9cae9 100644 --- a/memory-service/requirements.txt +++ b/memory-service/requirements.txt @@ -3,3 +3,4 @@ uvicorn>=0.34,<1.0 asyncpg>=0.30,<1.0 pgvector>=0.3,<1.0 httpx>=0.27,<1.0 +cryptography>=44.0,<45.0 diff --git a/migrate_rls.sql b/migrate_rls.sql new file mode 100644 index 0000000..d8e2bdc --- /dev/null +++ b/migrate_rls.sql @@ -0,0 +1,40 @@ +-- MAT-107: Row-Level Security for memory tables +-- Run as superuser (memory) which owns the tables + +-- Create restricted role for memory-service +DO $$ +BEGIN + IF NOT EXISTS (SELECT FROM pg_roles WHERE rolname = 'memory_app') THEN + CREATE ROLE memory_app LOGIN PASSWORD 'OhugBZP4g4d7rk3OszOq1Xe3yo7hQwEn'; + END IF; +END +$$; + +-- Grant permissions to memory_app +GRANT CONNECT ON DATABASE memories TO memory_app; +GRANT USAGE ON SCHEMA public TO memory_app; +GRANT SELECT, INSERT, DELETE ON memories, 
diff --git a/migrate_rls.sql b/migrate_rls.sql
new file mode 100644
index 0000000..d8e2bdc
--- /dev/null
+++ b/migrate_rls.sql
@@ -0,0 +1,40 @@
+-- MAT-107: Row-Level Security for memory tables
+-- Run as the superuser (memory), which owns the tables
+
+-- Create restricted role for memory-service
+DO $$
+BEGIN
+    IF NOT EXISTS (SELECT FROM pg_roles WHERE rolname = 'memory_app') THEN
+        CREATE ROLE memory_app LOGIN PASSWORD 'OhugBZP4g4d7rk3OszOq1Xe3yo7hQwEn';
+    END IF;
+END
+$$;
+
+-- Grant permissions to memory_app
+GRANT CONNECT ON DATABASE memories TO memory_app;
+GRANT USAGE ON SCHEMA public TO memory_app;
+GRANT SELECT, INSERT, DELETE ON memories, conversation_chunks TO memory_app;
+GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO memory_app;
+
+-- Ensure tables are owned by memory; owners (and superusers) bypass RLS
+ALTER TABLE memories OWNER TO memory;
+ALTER TABLE conversation_chunks OWNER TO memory;
+
+-- Enable RLS
+ALTER TABLE memories ENABLE ROW LEVEL SECURITY;
+ALTER TABLE conversation_chunks ENABLE ROW LEVEL SECURITY;
+
+-- Drop existing policies if re-running
+DROP POLICY IF EXISTS user_isolation_memories ON memories;
+DROP POLICY IF EXISTS user_isolation_chunks ON conversation_chunks;
+
+-- RLS policies: rows are visible only when the session variable matches user_id;
+-- current_setting(..., true) returns NULL when the variable is unset, so no rows match
+CREATE POLICY user_isolation_memories ON memories
+    USING (user_id = current_setting('app.current_user_id', true));
+
+CREATE POLICY user_isolation_chunks ON conversation_chunks
+    USING (user_id = current_setting('app.current_user_id', true));
+
+-- Verify
+SELECT tablename, rowsecurity FROM pg_tables WHERE tablename IN ('memories', 'conversation_chunks');
diff --git a/setup_ssl.sh b/setup_ssl.sh
new file mode 100644
index 0000000..d598b00
--- /dev/null
+++ b/setup_ssl.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+# MAT-107: Generate a self-signed SSL cert for memory-db and set file permissions
+set -euo pipefail
+
+SSL_DIR="/opt/matrix-ai-agent/memory-db-ssl"
+mkdir -p "$SSL_DIR"
+
+# Generate self-signed cert (valid 10 years)
+openssl req -new -x509 -days 3650 -nodes \
+    -subj "/CN=memory-db" \
+    -keyout "$SSL_DIR/server.key" \
+    -out "$SSL_DIR/server.crt" \
+    2>/dev/null
+
+# Postgres requires specific permissions
+chmod 600 "$SSL_DIR/server.key"
+chmod 644 "$SSL_DIR/server.crt"
+# Postgres runs as uid 999 in the pgvector container
+chown 999:999 "$SSL_DIR/server.key" "$SSL_DIR/server.crt"
+
+echo "SSL certs generated in $SSL_DIR"
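Once setup_ssl.sh has run and memory-db has restarted with the new command, it's worth confirming that Postgres actually terminates TLS. A sketch using `pg_stat_ssl`; host, port, and credentials are placeholders, and since the cert is self-signed, `ssl="require"` checks only that the transport is encrypted without verifying the chain (matching libpq's sslmode=require semantics):

```python
# Sketch: confirm the backend connection is TLS. Credentials/host are
# placeholders; adjust to how memory-db is exposed in your deployment.
import asyncio

import asyncpg

async def main() -> None:
    conn = await asyncpg.connect(
        host="localhost", port=5432, user="memory", password="memory",
        database="memories", ssl="require",
    )
    row = await conn.fetchrow(
        "SELECT ssl, version, cipher FROM pg_stat_ssl WHERE pid = pg_backend_pid()"
    )
    print(dict(row))  # e.g. {'ssl': True, 'version': 'TLSv1.3', 'cipher': ...}
    await conn.close()

asyncio.run(main())
```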