feat(MAT-107): memory encryption & user isolation
- Per-user Fernet encryption for fact/chunk_text/summary fields - Postgres RLS with memory_app restricted role - SSL for memory-db connections - Data migration script (migrate_encrypt.py) - DB migration (migrate_rls.sql) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -37,6 +37,13 @@ services:
|
|||||||
POSTGRES_DB: memories
|
POSTGRES_DB: memories
|
||||||
volumes:
|
volumes:
|
||||||
- memory-pgdata:/var/lib/postgresql/data
|
- memory-pgdata:/var/lib/postgresql/data
|
||||||
|
- ./memory-db-ssl/server.crt:/var/lib/postgresql/server.crt:ro
|
||||||
|
- ./memory-db-ssl/server.key:/var/lib/postgresql/server.key:ro
|
||||||
|
command: >
|
||||||
|
postgres
|
||||||
|
-c ssl=on
|
||||||
|
-c ssl_cert_file=/var/lib/postgresql/server.crt
|
||||||
|
-c ssl_key_file=/var/lib/postgresql/server.key
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "pg_isready -U memory -d memories"]
|
test: ["CMD-SHELL", "pg_isready -U memory -d memories"]
|
||||||
interval: 5s
|
interval: 5s
|
||||||
@@ -47,7 +54,9 @@ services:
|
|||||||
build: ./memory-service
|
build: ./memory-service
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: postgresql://memory:${MEMORY_DB_PASSWORD:-memory}@memory-db:5432/memories
|
DATABASE_URL: postgresql://memory_app:${MEMORY_APP_PASSWORD}@memory-db:5432/memories?sslmode=require
|
||||||
|
MEMORY_ENCRYPTION_KEY: ${MEMORY_ENCRYPTION_KEY}
|
||||||
|
MEMORY_DB_OWNER_PASSWORD: ${MEMORY_DB_PASSWORD:-memory}
|
||||||
LITELLM_BASE_URL: ${LITELLM_BASE_URL}
|
LITELLM_BASE_URL: ${LITELLM_BASE_URL}
|
||||||
LITELLM_API_KEY: ${LITELLM_MASTER_KEY}
|
LITELLM_API_KEY: ${LITELLM_MASTER_KEY}
|
||||||
EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
|
EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
|
||||||
|
|||||||
@@ -2,5 +2,5 @@ FROM python:3.11-slim
|
|||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
COPY main.py .
|
COPY main.py migrate_encrypt.py ./
|
||||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8090"]
|
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8090"]
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import hashlib
|
||||||
|
import base64
|
||||||
|
|
||||||
import asyncpg
|
import asyncpg
|
||||||
import httpx
|
import httpx
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -11,14 +14,49 @@ logger = logging.getLogger("memory-service")
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
|
DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
|
||||||
|
OWNER_DSN = os.environ.get("OWNER_DATABASE_URL", "postgresql://memory:{pw}@memory-db:5432/memories".format(
|
||||||
|
pw=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
|
||||||
|
))
|
||||||
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
|
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
|
||||||
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
|
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
|
||||||
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
|
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
|
||||||
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
|
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
|
||||||
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
|
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
|
||||||
|
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
|
||||||
|
|
||||||
app = FastAPI(title="Memory Service")
|
app = FastAPI(title="Memory Service")
|
||||||
pool: asyncpg.Pool | None = None
|
pool: asyncpg.Pool | None = None
|
||||||
|
owner_pool: asyncpg.Pool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _derive_user_key(user_id: str) -> bytes:
    """Derive a per-user Fernet key from master key + user_id via HMAC-SHA256.

    PBKDF2 with iterations=1 is equivalent to a single HMAC pass. That is
    acceptable here because ENCRYPTION_KEY is expected to be a high-entropy
    master secret, not a human password: the KDF only binds the key to a
    user_id, it is not stretching a weak input.

    WARNING: changing the derivation (algorithm, iteration count, encoding)
    makes every previously written ciphertext undecryptable. Must stay
    byte-identical to the copy in migrate_encrypt.py.

    Raises:
        RuntimeError: if MEMORY_ENCRYPTION_KEY is not configured.
    """
    if not ENCRYPTION_KEY:
        raise RuntimeError("MEMORY_ENCRYPTION_KEY not set")
    derived = hashlib.pbkdf2_hmac(
        "sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1
    )
    # Fernet requires a 32-byte urlsafe-base64-encoded key.
    return base64.urlsafe_b64encode(derived)
|
||||||
|
|
||||||
|
|
||||||
|
def _encrypt(text: str, user_id: str) -> str:
    """Encrypt text with per-user Fernet key. Returns base64 ciphertext.

    When no master key is configured, encryption is disabled and the
    plaintext is returned unchanged (mirrors _decrypt's no-key passthrough,
    so the service still works in unencrypted deployments).
    """
    if not ENCRYPTION_KEY:
        return text
    f = Fernet(_derive_user_key(user_id))
    return f.encrypt(text.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _decrypt(ciphertext: str, user_id: str) -> str:
    """Decrypt ciphertext with per-user Fernet key.

    Deliberately best-effort: any failure (wrong key, input that is not a
    Fernet token, bad base64) returns the input unchanged so plaintext rows
    written before migrate_encrypt.py ran keep being served. The broad
    except is intentional for that fallback — do not narrow it while
    unmigrated rows may still exist.
    """
    if not ENCRYPTION_KEY:
        return ciphertext
    try:
        f = Fernet(_derive_user_key(user_id))
        return f.decrypt(ciphertext.encode()).decode()
    except Exception:
        # Plaintext fallback for not-yet-migrated rows
        return ciphertext
|
||||||
|
|
||||||
|
|
||||||
class StoreRequest(BaseModel):
|
class StoreRequest(BaseModel):
|
||||||
@@ -78,11 +116,18 @@ async def _embed_batch(texts: list[str]) -> list[list[float]]:
|
|||||||
return [item["embedding"] for item in sorted(data, key=lambda x: x["index"])]
|
return [item["embedding"] for item in sorted(data, key=lambda x: x["index"])]
|
||||||
|
|
||||||
|
|
||||||
|
async def _set_rls_user(conn, user_id: str):
    """Set the RLS session variable for the current connection.

    The RLS policies compare user_id against
    current_setting('app.current_user_id', true). is_local=false makes the
    setting session-scoped, and since connections are pooled, callers re-set
    it immediately after every pool.acquire() before user-scoped queries.
    """
    await conn.execute("SELECT set_config('app.current_user_id', $1, false)", user_id)
|
||||||
|
|
||||||
|
|
||||||
async def _init_db():
|
async def _init_db():
|
||||||
"""Create pgvector extension and memories table if not exists."""
|
"""Create pgvector extension and memories table if not exists."""
|
||||||
global pool
|
global pool
|
||||||
pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
|
# Use owner connection for DDL (CREATE TABLE/INDEX), then create restricted pool
|
||||||
async with pool.acquire() as conn:
|
owner_conn = await asyncpg.connect(OWNER_DSN)
|
||||||
|
conn = owner_conn
|
||||||
|
try:
|
||||||
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||||
await conn.execute(f"""
|
await conn.execute(f"""
|
||||||
CREATE TABLE IF NOT EXISTS memories (
|
CREATE TABLE IF NOT EXISTS memories (
|
||||||
@@ -121,7 +166,14 @@ async def _init_db():
|
|||||||
await conn.execute("""
|
await conn.execute("""
|
||||||
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
|
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
|
||||||
""")
|
""")
|
||||||
logger.info("Database initialized (dims=%d)", EMBED_DIMS)
|
finally:
|
||||||
|
await owner_conn.close()
|
||||||
|
# Create restricted pool for all request handlers (RLS applies)
|
||||||
|
pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
|
||||||
|
# Owner pool for admin queries (bypasses RLS) — 1 connection only
|
||||||
|
global owner_pool
|
||||||
|
owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
|
||||||
|
logger.info("Database initialized (dims=%d, encryption=%s)", EMBED_DIMS, "ON" if ENCRYPTION_KEY else "OFF")
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
@@ -133,21 +185,28 @@ async def startup():
|
|||||||
async def shutdown():
|
async def shutdown():
|
||||||
if pool:
|
if pool:
|
||||||
await pool.close()
|
await pool.close()
|
||||||
|
if owner_pool:
|
||||||
|
await owner_pool.close()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
async def health():
    """Liveness/stats probe.

    Counts run on the owner pool: the restricted memory_app role is subject
    to RLS and (with no session user set) would see zero rows, so global
    totals must bypass RLS via the owner connection.
    """
    if owner_pool:
        async with owner_pool.acquire() as conn:
            mem_count = await conn.fetchval("SELECT count(*) FROM memories")
            chunk_count = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
        return {
            "status": "ok",
            "total_memories": mem_count,
            "total_chunks": chunk_count,
            "encryption": "on" if ENCRYPTION_KEY else "off",
        }
    return {"status": "no_db"}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/memories/store")
|
@app.post("/memories/store")
|
||||||
async def store_memory(req: StoreRequest):
|
async def store_memory(req: StoreRequest):
|
||||||
"""Embed fact, deduplicate by cosine similarity, insert."""
|
"""Embed fact, deduplicate by cosine similarity, insert encrypted."""
|
||||||
if not req.fact.strip():
|
if not req.fact.strip():
|
||||||
raise HTTPException(400, "Empty fact")
|
raise HTTPException(400, "Empty fact")
|
||||||
|
|
||||||
@@ -155,6 +214,8 @@ async def store_memory(req: StoreRequest):
|
|||||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
|
await _set_rls_user(conn, req.user_id)
|
||||||
|
|
||||||
# Check for duplicates (cosine similarity > threshold)
|
# Check for duplicates (cosine similarity > threshold)
|
||||||
dup = await conn.fetchval(
|
dup = await conn.fetchval(
|
||||||
"""
|
"""
|
||||||
@@ -169,12 +230,13 @@ async def store_memory(req: StoreRequest):
|
|||||||
logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, dup)
|
logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, dup)
|
||||||
return {"stored": False, "reason": "duplicate"}
|
return {"stored": False, "reason": "duplicate"}
|
||||||
|
|
||||||
|
encrypted_fact = _encrypt(req.fact.strip(), req.user_id)
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
|
INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
|
||||||
VALUES ($1, $2, $3, $4, $5::vector)
|
VALUES ($1, $2, $3, $4, $5::vector)
|
||||||
""",
|
""",
|
||||||
req.user_id, req.fact.strip(), req.source_room, time.time(), vec_literal,
|
req.user_id, encrypted_fact, req.source_room, time.time(), vec_literal,
|
||||||
)
|
)
|
||||||
logger.info("Stored memory for %s: %s", req.user_id, req.fact[:60])
|
logger.info("Stored memory for %s: %s", req.user_id, req.fact[:60])
|
||||||
return {"stored": True}
|
return {"stored": True}
|
||||||
@@ -187,6 +249,8 @@ async def query_memories(req: QueryRequest):
|
|||||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
|
await _set_rls_user(conn, req.user_id)
|
||||||
|
|
||||||
rows = await conn.fetch(
|
rows = await conn.fetch(
|
||||||
"""
|
"""
|
||||||
SELECT fact, source_room, created_at,
|
SELECT fact, source_room, created_at,
|
||||||
@@ -201,7 +265,7 @@ async def query_memories(req: QueryRequest):
|
|||||||
|
|
||||||
results = [
|
results = [
|
||||||
{
|
{
|
||||||
"fact": r["fact"],
|
"fact": _decrypt(r["fact"], req.user_id),
|
||||||
"source_room": r["source_room"],
|
"source_room": r["source_room"],
|
||||||
"created_at": r["created_at"],
|
"created_at": r["created_at"],
|
||||||
"similarity": float(r["similarity"]),
|
"similarity": float(r["similarity"]),
|
||||||
@@ -215,6 +279,7 @@ async def query_memories(req: QueryRequest):
|
|||||||
async def delete_user_memories(user_id: str):
|
async def delete_user_memories(user_id: str):
|
||||||
"""GDPR delete — remove all memories for a user."""
|
"""GDPR delete — remove all memories for a user."""
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
|
await _set_rls_user(conn, user_id)
|
||||||
result = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
|
result = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
|
||||||
count = int(result.split()[-1])
|
count = int(result.split()[-1])
|
||||||
logger.info("Deleted %d memories for %s", count, user_id)
|
logger.info("Deleted %d memories for %s", count, user_id)
|
||||||
@@ -225,6 +290,7 @@ async def delete_user_memories(user_id: str):
|
|||||||
async def list_user_memories(user_id: str):
|
async def list_user_memories(user_id: str):
|
||||||
"""List all memories for a user (for UI/debug)."""
|
"""List all memories for a user (for UI/debug)."""
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
|
await _set_rls_user(conn, user_id)
|
||||||
rows = await conn.fetch(
|
rows = await conn.fetch(
|
||||||
"""
|
"""
|
||||||
SELECT fact, source_room, created_at
|
SELECT fact, source_room, created_at
|
||||||
@@ -238,7 +304,7 @@ async def list_user_memories(user_id: str):
|
|||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"count": len(rows),
|
"count": len(rows),
|
||||||
"memories": [
|
"memories": [
|
||||||
{"fact": r["fact"], "source_room": r["source_room"], "created_at": r["created_at"]}
|
{"fact": _decrypt(r["fact"], user_id), "source_room": r["source_room"], "created_at": r["created_at"]}
|
||||||
for r in rows
|
for r in rows
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -249,7 +315,7 @@ async def list_user_memories(user_id: str):
|
|||||||
|
|
||||||
@app.post("/chunks/store")
|
@app.post("/chunks/store")
|
||||||
async def store_chunk(req: ChunkStoreRequest):
|
async def store_chunk(req: ChunkStoreRequest):
|
||||||
"""Store a conversation chunk with its summary embedding."""
|
"""Store a conversation chunk with its summary embedding, encrypted."""
|
||||||
if not req.summary.strip():
|
if not req.summary.strip():
|
||||||
raise HTTPException(400, "Empty summary")
|
raise HTTPException(400, "Empty summary")
|
||||||
|
|
||||||
@@ -257,14 +323,18 @@ async def store_chunk(req: ChunkStoreRequest):
|
|||||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
ts = req.original_ts or time.time()
|
ts = req.original_ts or time.time()
|
||||||
|
|
||||||
|
encrypted_text = _encrypt(req.chunk_text, req.user_id)
|
||||||
|
encrypted_summary = _encrypt(req.summary, req.user_id)
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
|
await _set_rls_user(conn, req.user_id)
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO conversation_chunks
|
INSERT INTO conversation_chunks
|
||||||
(user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
|
(user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
|
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
|
||||||
""",
|
""",
|
||||||
req.user_id, req.room_id, req.chunk_text, req.summary,
|
req.user_id, req.room_id, encrypted_text, encrypted_summary,
|
||||||
req.source_event_id, ts, vec_literal,
|
req.source_event_id, ts, vec_literal,
|
||||||
)
|
)
|
||||||
logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60])
|
logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60])
|
||||||
@@ -294,7 +364,13 @@ async def query_chunks(req: ChunkQueryRequest):
|
|||||||
where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||||
params.append(req.top_k)
|
params.append(req.top_k)
|
||||||
|
|
||||||
|
# Determine user_id for RLS and decryption
|
||||||
|
rls_user = req.user_id if req.user_id else ""
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
|
if rls_user:
|
||||||
|
await _set_rls_user(conn, rls_user)
|
||||||
|
|
||||||
rows = await conn.fetch(
|
rows = await conn.fetch(
|
||||||
f"""
|
f"""
|
||||||
SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
|
SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
|
||||||
@@ -309,8 +385,8 @@ async def query_chunks(req: ChunkQueryRequest):
|
|||||||
|
|
||||||
results = [
|
results = [
|
||||||
{
|
{
|
||||||
"chunk_text": r["chunk_text"],
|
"chunk_text": _decrypt(r["chunk_text"], r["user_id"]),
|
||||||
"summary": r["summary"],
|
"summary": _decrypt(r["summary"], r["user_id"]),
|
||||||
"room_id": r["room_id"],
|
"room_id": r["room_id"],
|
||||||
"user_id": r["user_id"],
|
"user_id": r["user_id"],
|
||||||
"original_ts": r["original_ts"],
|
"original_ts": r["original_ts"],
|
||||||
@@ -343,15 +419,18 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest):
|
|||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
for chunk, embedding in zip(batch, embeddings):
|
for chunk, embedding in zip(batch, embeddings):
|
||||||
|
await _set_rls_user(conn, chunk.user_id)
|
||||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
ts = chunk.original_ts or time.time()
|
ts = chunk.original_ts or time.time()
|
||||||
|
encrypted_text = _encrypt(chunk.chunk_text, chunk.user_id)
|
||||||
|
encrypted_summary = _encrypt(chunk.summary, chunk.user_id)
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO conversation_chunks
|
INSERT INTO conversation_chunks
|
||||||
(user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
|
(user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
|
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
|
||||||
""",
|
""",
|
||||||
chunk.user_id, chunk.room_id, chunk.chunk_text, chunk.summary,
|
chunk.user_id, chunk.room_id, encrypted_text, encrypted_summary,
|
||||||
chunk.source_event_id, ts, vec_literal,
|
chunk.source_event_id, ts, vec_literal,
|
||||||
)
|
)
|
||||||
stored += 1
|
stored += 1
|
||||||
@@ -364,6 +443,7 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest):
|
|||||||
async def count_user_chunks(user_id: str):
|
async def count_user_chunks(user_id: str):
|
||||||
"""Count conversation chunks for a user."""
|
"""Count conversation chunks for a user."""
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
|
await _set_rls_user(conn, user_id)
|
||||||
count = await conn.fetchval(
|
count = await conn.fetchval(
|
||||||
"SELECT count(*) FROM conversation_chunks WHERE user_id = $1", user_id,
|
"SELECT count(*) FROM conversation_chunks WHERE user_id = $1", user_id,
|
||||||
)
|
)
|
||||||
|
|||||||
95
memory-service/migrate_encrypt.py
Normal file
95
memory-service/migrate_encrypt.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""MAT-107: One-time migration to encrypt existing plaintext memory data.
|
||||||
|
|
||||||
|
Run INSIDE the memory-service container after deploying new code:
|
||||||
|
docker exec -it matrix-ai-agent-memory-service-1 python migrate_encrypt.py
|
||||||
|
|
||||||
|
Connects as owner (memory) to bypass RLS.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import hashlib
|
||||||
|
import base64
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
OWNER_DSN = os.environ.get(
|
||||||
|
"OWNER_DATABASE_URL",
|
||||||
|
"postgresql://memory:{password}@memory-db:5432/memories".format(
|
||||||
|
password=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _derive_user_key(user_id: str) -> bytes:
    # Must produce byte-identical keys to main.py's _derive_user_key,
    # otherwise the service cannot decrypt what this script writes.
    # iterations=1 is a single HMAC pass — fine for a high-entropy master key.
    derived = hashlib.pbkdf2_hmac("sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1)
    return base64.urlsafe_b64encode(derived)
|
||||||
|
|
||||||
|
|
||||||
|
def _encrypt(text: str, user_id: str) -> str:
    # Encrypt with the per-user Fernet key (same derivation as the service).
    f = Fernet(_derive_user_key(user_id))
    return f.encrypt(text.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_encrypted(text: str, user_id: str) -> bool:
    """Return True if *text* is already a Fernet token under this user's key.

    A successful decrypt proves the row was written with the current key
    derivation; any failure means plaintext (or a foreign key) and the
    caller should encrypt it.
    """
    try:
        f = Fernet(_derive_user_key(user_id))
        f.decrypt(text.encode())
        return True
    except Exception:
        # Fix: the original caught `(InvalidToken, Exception)`, which is
        # redundant — Exception already subsumes InvalidToken. Behavior is
        # unchanged: anything that fails to decrypt (InvalidToken, binascii
        # errors on non-base64 input, etc.) is treated as not-encrypted.
        return False
|
||||||
|
|
||||||
|
|
||||||
|
async def migrate():
    """One-shot encryption of existing plaintext memory data.

    Idempotent: rows whose stored value already decrypts with the per-user
    key are counted as skipped and left untouched, so the script is safe to
    re-run after a partial failure.
    """
    if not ENCRYPTION_KEY:
        print("ERROR: MEMORY_ENCRYPTION_KEY not set")
        sys.exit(1)

    # Owner connection bypasses RLS so every user's rows are visible.
    conn = await asyncpg.connect(OWNER_DSN)

    # Migrate memories
    rows = await conn.fetch("SELECT id, user_id, fact FROM memories ORDER BY id")
    print(f"Migrating {len(rows)} memories...")
    encrypted = 0
    skipped = 0
    for row in rows:
        if _is_encrypted(row["fact"], row["user_id"]):
            skipped += 1
            continue
        enc_fact = _encrypt(row["fact"], row["user_id"])
        await conn.execute("UPDATE memories SET fact = $1 WHERE id = $2", enc_fact, row["id"])
        encrypted += 1
        if encrypted % 100 == 0:
            print(f" memories: {encrypted}/{len(rows)} encrypted")
    print(f"Memories done: {encrypted} encrypted, {skipped} already encrypted")

    # Migrate conversation_chunks
    rows = await conn.fetch("SELECT id, user_id, chunk_text, summary FROM conversation_chunks ORDER BY id")
    print(f"Migrating {len(rows)} chunks...")
    encrypted = 0
    skipped = 0
    for row in rows:
        # chunk_text is used as the encryption marker for the whole row;
        # summary is assumed to be in the same state since both columns are
        # written by one INSERT in the service.
        if _is_encrypted(row["chunk_text"], row["user_id"]):
            skipped += 1
            continue
        enc_text = _encrypt(row["chunk_text"], row["user_id"])
        enc_summary = _encrypt(row["summary"], row["user_id"])
        await conn.execute(
            "UPDATE conversation_chunks SET chunk_text = $1, summary = $2 WHERE id = $3",
            enc_text, enc_summary, row["id"],
        )
        encrypted += 1
        if encrypted % 500 == 0:
            print(f" chunks: {encrypted}/{len(rows)} encrypted")
    print(f"Chunks done: {encrypted} encrypted, {skipped} already encrypted")

    await conn.close()
    print("Migration complete.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(migrate())
|
||||||
@@ -3,3 +3,4 @@ uvicorn>=0.34,<1.0
|
|||||||
asyncpg>=0.30,<1.0
|
asyncpg>=0.30,<1.0
|
||||||
pgvector>=0.3,<1.0
|
pgvector>=0.3,<1.0
|
||||||
httpx>=0.27,<1.0
|
httpx>=0.27,<1.0
|
||||||
|
cryptography>=44.0,<45.0
|
||||||
|
|||||||
40
migrate_rls.sql
Normal file
40
migrate_rls.sql
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
-- MAT-107: Row-Level Security for memory tables
-- Run as superuser (memory) which owns the tables

-- Create restricted role for memory-service
-- SECURITY NOTE(review): this password is committed to the repository.
-- Rotate it after applying (ALTER ROLE memory_app PASSWORD '...') and keep
-- it in sync with MEMORY_APP_PASSWORD in the compose environment.
DO $$
BEGIN
    IF NOT EXISTS (SELECT FROM pg_roles WHERE rolname = 'memory_app') THEN
        CREATE ROLE memory_app LOGIN PASSWORD 'OhugBZP4g4d7rk3OszOq1Xe3yo7hQwEn';
    END IF;
END
$$;

-- Grant permissions to memory_app
-- (no UPDATE granted: the service only inserts, reads and deletes)
GRANT CONNECT ON DATABASE memories TO memory_app;
GRANT USAGE ON SCHEMA public TO memory_app;
GRANT SELECT, INSERT, DELETE ON memories, conversation_chunks TO memory_app;
GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO memory_app;

-- Ensure tables are owned by memory (superuser) so RLS doesn't apply to owner
ALTER TABLE memories OWNER TO memory;
ALTER TABLE conversation_chunks OWNER TO memory;

-- Enable RLS
ALTER TABLE memories ENABLE ROW LEVEL SECURITY;
ALTER TABLE conversation_chunks ENABLE ROW LEVEL SECURITY;

-- Drop existing policies if re-running
DROP POLICY IF EXISTS user_isolation_memories ON memories;
DROP POLICY IF EXISTS user_isolation_chunks ON conversation_chunks;

-- RLS policies: rows visible only when session var matches user_id
-- current_setting with missing_ok=true returns empty string if not set
-- NOTE(review): on recent Postgres an unset custom GUC yields NULL here;
-- NULL = user_id is never true, so a connection without the session
-- variable set sees zero rows either way (fails closed) — verify on the
-- deployed Postgres version.
CREATE POLICY user_isolation_memories ON memories
    USING (user_id = current_setting('app.current_user_id', true));

CREATE POLICY user_isolation_chunks ON conversation_chunks
    USING (user_id = current_setting('app.current_user_id', true));

-- Verify
SELECT tablename, rowsecurity FROM pg_tables WHERE tablename IN ('memories', 'conversation_chunks');
|
||||||
21
setup_ssl.sh
Normal file
21
setup_ssl.sh
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
#!/bin/bash
# MAT-107: Generate self-signed SSL cert for memory-db and configure postgres
set -euo pipefail

SSL_DIR="/opt/matrix-ai-agent/memory-db-ssl"
mkdir -p "$SSL_DIR"

# Generate self-signed cert (valid 10 years)
# NOTE(review): stderr is discarded; a non-zero exit still aborts the
# script via `set -e`, but drop `2>/dev/null` when debugging openssl issues.
openssl req -new -x509 -days 3650 -nodes \
  -subj "/CN=memory-db" \
  -keyout "$SSL_DIR/server.key" \
  -out "$SSL_DIR/server.crt" \
  2>/dev/null

# Postgres requires specific permissions
# (postgres refuses to start if server.key is group/world readable)
chmod 600 "$SSL_DIR/server.key"
chmod 644 "$SSL_DIR/server.crt"
# Postgres runs as uid 999 in the pgvector container
# NOTE(review): uid 999 is image-specific — confirm against the pinned image.
chown 999:999 "$SSL_DIR/server.key" "$SSL_DIR/server.crt"

echo "SSL certs generated in $SSL_DIR"
|
||||||
Reference in New Issue
Block a user