diff --git a/bot.py b/bot.py index aacc8b6..0b0d773 100644 --- a/bot.py +++ b/bot.py @@ -50,6 +50,7 @@ LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed") DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet") WILDFILES_BASE_URL = os.environ.get("WILDFILES_BASE_URL", "") WILDFILES_ORG = os.environ.get("WILDFILES_ORG", "") +USER_KEYS_FILE = os.environ.get("USER_KEYS_FILE", "/data/user_keys.json") SYSTEM_PROMPT = """You are a helpful AI assistant in a Matrix chat room. Keep answers concise but thorough. Use markdown formatting when helpful. @@ -69,6 +70,8 @@ HELP_TEXT = """**AI Bot Commands** - `!ai models` — List available models - `!ai set-model ` — Set model for this room - `!ai search ` — Search documents (WildFiles) +- `!ai connect ` — Connect your WildFiles account (DM only) +- `!ai disconnect` — Disconnect your WildFiles account - `!ai auto-rename on|off` — Auto-rename room based on conversation topic - **@mention the bot** or start with `!ai` for a regular AI response""" @@ -81,14 +84,21 @@ class DocumentRAG: self.org = org self.enabled = bool(base_url and org) - async def search(self, query: str, top_k: int = 3) -> list[dict]: - if not self.enabled: + async def search(self, query: str, top_k: int = 3, api_key: str | None = None) -> list[dict]: + if not api_key and not self.enabled: return [] try: + headers = {} + body = {"query": query, "limit": top_k} + if api_key: + headers["X-API-Key"] = api_key + else: + body["organization"] = self.org async with httpx.AsyncClient(timeout=15.0) as client: resp = await client.post( f"{self.base_url}/api/v1/rag/search", - json={"query": query, "organization": self.org, "limit": top_k}, + json=body, + headers=headers, ) resp.raise_for_status() return resp.json().get("results", []) @@ -96,6 +106,24 @@ class DocumentRAG: logger.debug("WildFiles search failed", exc_info=True) return [] + async def validate_key(self, api_key: str) -> dict | None: + """Validate an API key against WildFiles. Returns stats dict or None.""" + if not self.base_url: + return None + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get( + f"{self.base_url}/api/v1/rag/stats", + headers={"X-API-Key": api_key}, + ) + resp.raise_for_status() + data = resp.json() + if data.get("total_documents", 0) >= 0: + return data + except Exception: + logger.debug("WildFiles key validation failed", exc_info=True) + return None + def format_context(self, results: list[dict]) -> str: if not results: return "" @@ -146,6 +174,7 @@ class Bot: self.active_calls = set() # rooms where we've sent call member event self.rag = DocumentRAG(WILDFILES_BASE_URL, WILDFILES_ORG) self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None + self.user_keys: dict[str, str] = self._load_user_keys() # matrix_user_id -> api_key self.room_models: dict[str, str] = {} # room_id -> model name self.auto_rename_rooms: set[str] = set() # rooms with auto-rename enabled self.renamed_rooms: dict[str, float] = {} # room_id -> timestamp of last rename @@ -153,6 +182,24 @@ class Bot: self._sync_token_received = False self._verifications: dict[str, dict] = {} # txn_id -> verification state + @staticmethod + def _load_user_keys() -> dict[str, str]: + if os.path.exists(USER_KEYS_FILE): + try: + with open(USER_KEYS_FILE) as f: + return json.load(f) + except Exception: + logger.warning("Failed to load user keys file, starting fresh") + return {} + + def _save_user_keys(self): + try: + os.makedirs(os.path.dirname(USER_KEYS_FILE), exist_ok=True) + with open(USER_KEYS_FILE, "w") as f: + json.dump(self.user_keys, f) + except Exception: + logger.exception("Failed to save user keys") + async def start(self): # Restore existing session or create new one if os.path.exists(CREDS_FILE): @@ -340,7 +387,7 @@ class Bot: # Command handling if body.startswith("!ai "): cmd = body[4:].strip() - await self._handle_command(room, cmd) + await self._handle_command(room, cmd, event) return if body == "!ai": await self._send_text(room.room_id, HELP_TEXT) @@ -364,14 +411,20 @@ class Bot: await self.client.room_typing(room.room_id, typing_state=True) try: - await self._respond_with_ai(room, body) + await self._respond_with_ai(room, body, sender=event.sender) finally: await self.client.room_typing(room.room_id, typing_state=False) - async def _handle_command(self, room, cmd: str): + async def _handle_command(self, room, cmd: str, event=None): if cmd == "help": await self._send_text(room.room_id, HELP_TEXT) + elif cmd.startswith("connect "): + await self._handle_connect(room, cmd[8:].strip(), event) + + elif cmd == "disconnect": + await self._handle_disconnect(room, event) + elif cmd == "models": if not self.llm: await self._send_text(room.room_id, "LLM not configured.") @@ -427,7 +480,9 @@ class Bot: if not query: await self._send_text(room.room_id, "Usage: `!ai search `") return - results = await self.rag.search(query, top_k=5) + sender = event.sender if event else None + user_api_key = self.user_keys.get(sender) if sender else None + results = await self.rag.search(query, top_k=5, api_key=user_api_key) if not results: await self._send_text(room.room_id, "No documents found.") return @@ -436,15 +491,64 @@ class Bot: else: # Treat unknown commands as AI prompts if self.llm: + sender = event.sender if event else None await self.client.room_typing(room.room_id, typing_state=True) try: - await self._respond_with_ai(room, cmd) + await self._respond_with_ai(room, cmd, sender=sender) finally: await self.client.room_typing(room.room_id, typing_state=False) else: await self._send_text(room.room_id, f"Unknown command: `{cmd}`\n\n{HELP_TEXT}") - async def _respond_with_ai(self, room, user_message: str): + async def _handle_connect(self, room, api_key: str, event=None): + """Handle !ai connect — validate and store user's WildFiles API key.""" + sender = event.sender if event else None + if not api_key: + await self._send_text(room.room_id, "Usage: `!ai connect `") + return + + # Redact the message containing the API key for security + if event: + try: + await self.client.room_redact(room.room_id, event.event_id, reason="API key redacted for security") + except Exception: + logger.debug("Could not redact connect message", exc_info=True) + + if not self.rag.base_url: + await self._send_text(room.room_id, "WildFiles is not configured.") + return + + # Validate the key + stats = await self.rag.validate_key(api_key) + if stats is None: + await self._send_text(room.room_id, "Invalid API key. Please check and try again.") + return + + # Store the key + self.user_keys[sender] = api_key + self._save_user_keys() + + org_name = stats.get("organization", "unknown") + total = stats.get("total_documents", 0) + await self._send_text( + room.room_id, + f"Connected to WildFiles (org: **{org_name}**, {total} documents). " + f"Your documents are now searchable.", + ) + logger.info("User %s connected WildFiles key (org: %s)", sender, org_name) + + async def _handle_disconnect(self, room, event=None): + """Handle !ai disconnect — remove stored WildFiles API key.""" + sender = event.sender if event else None + if sender and sender in self.user_keys: + del self.user_keys[sender] + self._save_user_keys() + await self._send_text(room.room_id, "Disconnected from WildFiles. Using default search.") + logger.info("User %s disconnected WildFiles key", sender) + else: + await self._send_text(room.room_id, "No WildFiles account connected.") + + async def _respond_with_ai(self, room, user_message: str, sender: str = None): model = self.room_models.get(room.room_id, DEFAULT_MODEL) # Fetch conversation history FIRST (needed for query rewriting) @@ -465,8 +569,9 @@ class Bot: # Rewrite query using conversation context for better RAG search search_query = await self._rewrite_query(user_message, history, model) - # WildFiles document context - doc_results = await self.rag.search(search_query) + # WildFiles document context (use per-user API key if available) + user_api_key = self.user_keys.get(sender) if sender else None + doc_results = await self.rag.search(search_query, api_key=user_api_key) doc_context = self.rag.format_context(doc_results) if doc_context: logger.info("RAG found %d docs for: %s (original: %s)", len(doc_results), search_query[:50], user_message[:50])