From 9578e0406b168c33a8726325da93e189b9f30796 Mon Sep 17 00:00:00 2001 From: Christian Gick Date: Tue, 3 Mar 2026 11:19:02 +0000 Subject: [PATCH] feat: Matrix E2EE key management + multi-user isolation - Add rag_key_manager.py: stores encryption key in private E2EE room - Bot loads key from Matrix on startup, injects into RAG via portal proxy - No plaintext key on disk (removed RAG_ENCRYPTION_KEY from .env) - Pass owner_id (matrix_user_id) to RAG search for user isolation - Stronger format_context instructions for source link rendering Co-Authored-By: Claude Opus 4.6 --- bot.py | 34 +++++++-- rag_key_manager.py | 186 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+), 5 deletions(-) create mode 100644 rag_key_manager.py diff --git a/bot.py b/bot.py index 907a692..c74622f 100644 --- a/bot.py +++ b/bot.py @@ -15,6 +15,7 @@ import fitz # pymupdf import httpx from openai import AsyncOpenAI from olm import sas as olm_sas +from rag_key_manager import RAGKeyManager from nio import ( AsyncClient, @@ -335,15 +336,17 @@ class DocumentRAG: # Prefer customer-VM RAG service (encrypted, local) if self.use_local_rag: - return await self._search_local(query, top_k) + return await self._search_local(query, top_k, matrix_user_id) # Fallback: central portal API (legacy, unencrypted) return await self._search_portal(query, top_k, matrix_user_id) - async def _search_local(self, query: str, top_k: int) -> list[dict]: + async def _search_local(self, query: str, top_k: int, matrix_user_id: str | None = None) -> list[dict]: """Search via customer-VM RAG service (localhost).""" try: body = {"query": query, "limit": top_k} + if matrix_user_id: + body["owner_id"] = matrix_user_id headers: dict[str, str] = {"Content-Type": "application/json"} if self.rag_auth_token: headers["Authorization"] = f"Bearer {self.rag_auth_token}" @@ -415,9 +418,13 @@ class DocumentRAG: parts.append(f"Content:\n{content}") parts.append("") # blank line between docs - parts.append("Use the document content above to answer the user's question. " - "When referencing documents, use markdown links: [Document Title](url). " - "Never show raw URLs.") + parts.append("IMPORTANT INSTRUCTIONS FOR DOCUMENT RESPONSES:\n" + "1. Answer the user's question using the document content above.\n" + "2. You MUST include a source link for EVERY document you reference.\n" + "3. Format links as markdown: [Document Title](url)\n" + "4. Place the link right after mentioning or quoting the document.\n" + "5. If a document has no link, skip the link but still reference the title.\n" + "6. Never show raw URLs without markdown formatting.") return "\n".join(parts) @@ -940,6 +947,7 @@ class Bot: self.active_callers: dict[str, set[str]] = {} # room_id → set of caller user IDs self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY, rag_endpoint=RAG_ENDPOINT, rag_auth_token=RAG_AUTH_TOKEN) + self.key_manager = RAGKeyManager(self.client, PORTAL_URL, BOT_API_KEY) self.memory = MemoryClient(MEMORY_SERVICE_URL) self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY) self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None @@ -1035,6 +1043,20 @@ class Bot: await self.client.sync_forever(timeout=30000, full_state=True) + async def _inject_rag_key(self): + """Load document encryption key from Matrix and inject into RAG service.""" + try: + seed_key = os.environ.get("RAG_ENCRYPTION_KEY_SEED") + success = await self.key_manager.ensure_rag_key(seed_key_hex=seed_key) + if success: + logger.info("RAG encryption key loaded from Matrix E2EE") + if seed_key: + logger.info("Migration complete - RAG_ENCRYPTION_KEY_SEED can now be removed from env") + else: + logger.warning("Failed to load RAG encryption key - documents will be inaccessible") + except Exception as e: + logger.error("RAG key injection failed: %s", e, exc_info=True) + async def on_invite(self, room, event: InviteMemberEvent): if event.state_key != BOT_USER: return @@ -1046,6 +1068,8 @@ class Bot: if not self._sync_token_received: self._sync_token_received = True logger.info("Initial sync complete, text handler active") + # Inject RAG encryption key from Matrix E2EE room + asyncio.create_task(self._inject_rag_key()) for user_id in list(self.client.device_store.users): for device in self.client.device_store.active_user_devices(user_id): if not device.verified: diff --git a/rag_key_manager.py b/rag_key_manager.py new file mode 100644 index 0000000..cc7bfba --- /dev/null +++ b/rag_key_manager.py @@ -0,0 +1,186 @@ +""" +RAG Document Key Manager — stores per-user encryption keys in Matrix E2EE rooms. + +The key is stored as an encrypted event in a private room that only the bot can access. +On startup, the bot syncs the room and re-injects the key into the RAG service +via the portal proxy (since RAG service is localhost-only on the customer VM). +No plaintext keys are ever written to disk. +""" + +import secrets +import logging +import httpx +from nio.api import RoomVisibility + +logger = logging.getLogger("matrix-ai-bot") + +KEY_EVENT_TYPE = "eu.matrixhost.rag_document_key" +KEY_ROOM_TOPIC = "RAG Document Encryption Keys \u2014 DO NOT LEAVE" + + +class RAGKeyManager: + """Manages per-user document encryption keys via Matrix E2EE.""" + + def __init__(self, client, portal_url: str, bot_api_key: str): + self.client = client + self.portal_url = portal_url.rstrip("/") if portal_url else "" + self.bot_api_key = bot_api_key + self._key_room_id: str | None = None + + async def ensure_rag_key(self, seed_key_hex: str | None = None) -> bool: + """Ensure RAG service has encryption key loaded. + + Args: + seed_key_hex: Existing key to migrate into Matrix storage (one-time). + """ + if not self.portal_url: + logger.warning("[rag-key] No portal URL configured") + return False + + # Check if RAG already has a key + if await self._rag_has_key(): + logger.info("[rag-key] RAG service already has key loaded") + room_id = await self._find_or_create_key_room() + if room_id: + existing = await self._load_key_from_room(room_id) + if not existing and seed_key_hex: + await self._store_key_in_room(room_id, seed_key_hex) + logger.info("[rag-key] Migrated existing key into Matrix E2EE room") + return True + + # Find or create the key storage room + room_id = await self._find_or_create_key_room() + if not room_id: + logger.error("[rag-key] Failed to find or create key room") + return False + + # Try to load existing key from room + key_hex = await self._load_key_from_room(room_id) + + if key_hex: + logger.info("[rag-key] Loaded existing key from Matrix room") + elif seed_key_hex: + key_hex = seed_key_hex + stored = await self._store_key_in_room(room_id, key_hex) + if not stored: + logger.error("[rag-key] Failed to store seed key in Matrix room") + return False + logger.info("[rag-key] Stored migration seed key in Matrix E2EE room") + else: + key_hex = secrets.token_hex(32) + stored = await self._store_key_in_room(room_id, key_hex) + if not stored: + logger.error("[rag-key] Failed to store new key in Matrix room") + return False + logger.info("[rag-key] Generated and stored new encryption key") + + return await self._inject_key(key_hex) + + async def _rag_has_key(self) -> bool: + try: + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.get( + f"{self.portal_url}/api/bot/rag-key", + headers={"Authorization": f"Bearer {self.bot_api_key}"}, + ) + resp.raise_for_status() + return resp.json().get("has_key", False) + except Exception as e: + logger.debug("[rag-key] Health check failed: %s", e) + return False + + async def _find_or_create_key_room(self) -> str | None: + for room_id, room in self.client.rooms.items(): + if room.topic == KEY_ROOM_TOPIC: + self._key_room_id = room_id + logger.info("[rag-key] Found existing key room: %s", room_id) + return room_id + + try: + from nio import EnableEncryptionBuilder + initial_state = [EnableEncryptionBuilder().as_dict()] + except ImportError: + initial_state = [{ + "type": "m.room.encryption", + "state_key": "", + "content": {"algorithm": "m.megolm.v1.aes-sha2"}, + }] + + resp = await self.client.room_create( + name="RAG Key Storage", + topic=KEY_ROOM_TOPIC, + invite=[], + initial_state=initial_state, + visibility=RoomVisibility.private, + ) + + if hasattr(resp, "room_id"): + self._key_room_id = resp.room_id + logger.info("[rag-key] Created new key room: %s", resp.room_id) + return resp.room_id + + logger.error("[rag-key] Failed to create key room: %s", resp) + return None + + async def _load_key_from_room(self, room_id: str) -> str | None: + try: + resp = await self.client.room_messages( + room_id, start="", limit=50, direction="b", + ) + if not hasattr(resp, "chunk"): + return None + + for event in resp.chunk: + if hasattr(event, "source"): + source = event.source + if source.get("type") == KEY_EVENT_TYPE: + key = source.get("content", {}).get("key_hex") + if key: + return key + + if hasattr(event, "type") and event.type == KEY_EVENT_TYPE: + if hasattr(event, "content"): + key = event.content.get("key_hex") + if key: + return key + + return None + except Exception as e: + logger.warning("[rag-key] Failed to load key from room: %s", e) + return None + + async def _store_key_in_room(self, room_id: str, key_hex: str) -> bool: + try: + content = { + "key_hex": key_hex, + "algorithm": "aes-256-gcm", + "purpose": "rag-document-encryption", + "msgtype": "eu.matrixhost.rag_key", + } + resp = await self.client.room_send( + room_id, message_type=KEY_EVENT_TYPE, + content=content, ignore_unverified_devices=True, + ) + if hasattr(resp, "event_id"): + logger.info("[rag-key] Key stored as event %s", resp.event_id) + return True + logger.error("[rag-key] Failed to send key event: %s", resp) + return False + except Exception as e: + logger.error("[rag-key] Failed to store key: %s", e) + return False + + async def _inject_key(self, key_hex: str) -> bool: + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{self.portal_url}/api/bot/rag-key", + json={"key_hex": key_hex}, + headers={"Authorization": f"Bearer {self.bot_api_key}"}, + ) + resp.raise_for_status() + logger.info("[rag-key] Key injected into RAG service via portal proxy") + return True + except Exception as e: + logger.error("[rag-key] Failed to inject key: %s", e) + return False