diff --git a/bot.py b/bot.py index 0b0d773..b999281 100644 --- a/bot.py +++ b/bot.py @@ -70,7 +70,7 @@ HELP_TEXT = """**AI Bot Commands** - `!ai models` — List available models - `!ai set-model ` — Set model for this room - `!ai search ` — Search documents (WildFiles) -- `!ai connect ` — Connect your WildFiles account (DM only) +- `!ai connect` — Connect your WildFiles account (opens browser approval) - `!ai disconnect` — Disconnect your WildFiles account - `!ai auto-rename on|off` — Auto-rename room based on conversation topic - **@mention the bot** or start with `!ai` for a regular AI response""" @@ -181,6 +181,7 @@ class Bot: self._loaded_rooms: set[str] = set() # rooms where we've loaded state self._sync_token_received = False self._verifications: dict[str, dict] = {} # txn_id -> verification state + self._pending_connects: dict[str, str] = {} # matrix_user_id -> device_code @staticmethod def _load_user_keys() -> dict[str, str]: @@ -419,8 +420,9 @@ class Bot: if cmd == "help": await self._send_text(room.room_id, HELP_TEXT) - elif cmd.startswith("connect "): - await self._handle_connect(room, cmd[8:].strip(), event) + elif cmd == "connect" or cmd.startswith("connect "): + api_key = cmd[8:].strip() if cmd.startswith("connect ") else "" + await self._handle_connect(room, api_key, event) elif cmd == "disconnect": await self._handle_disconnect(room, event) @@ -501,41 +503,69 @@ class Bot: await self._send_text(room.room_id, f"Unknown command: `{cmd}`\n\n{HELP_TEXT}") async def _handle_connect(self, room, api_key: str, event=None): - """Handle !ai connect — validate and store user's WildFiles API key.""" + """Handle !ai connect — SSO device flow, or !ai connect as fallback.""" sender = event.sender if event else None - if not api_key: - await self._send_text(room.room_id, "Usage: `!ai connect `") - return - - # Redact the message containing the API key for security - if event: - try: - await self.client.room_redact(room.room_id, event.event_id, reason="API key redacted for security") - except Exception: - logger.debug("Could not redact connect message", exc_info=True) if not self.rag.base_url: await self._send_text(room.room_id, "WildFiles is not configured.") return - # Validate the key - stats = await self.rag.validate_key(api_key) - if stats is None: - await self._send_text(room.room_id, "Invalid API key. Please check and try again.") + # Fallback: direct API key provided + if api_key: + # Redact the message containing the API key for security + if event: + try: + await self.client.room_redact(room.room_id, event.event_id, reason="API key redacted for security") + except Exception: + logger.debug("Could not redact connect message", exc_info=True) + + stats = await self.rag.validate_key(api_key) + if stats is None: + await self._send_text(room.room_id, "Invalid API key. Please check and try again.") + return + + self.user_keys[sender] = api_key + self._save_user_keys() + org_name = stats.get("organization", "unknown") + total = stats.get("total_documents", 0) + await self._send_text( + room.room_id, + f"Connected to WildFiles (org: **{org_name}**, {total} documents). " + f"Your documents are now searchable.", + ) + logger.info("User %s connected WildFiles key (org: %s)", sender, org_name) return - # Store the key - self.user_keys[sender] = api_key - self._save_user_keys() + # SSO device authorization flow + if sender and sender in self._pending_connects: + await self._send_text(room.room_id, "A connect flow is already in progress. Please complete or wait for it to expire.") + return + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post(f"{self.rag.base_url}/api/v1/auth/device/code") + resp.raise_for_status() + data = resp.json() + except Exception: + logger.exception("Failed to start device auth flow") + await self._send_text(room.room_id, "Failed to start connection flow. Please try again later.") + return + + device_code = data["device_code"] + user_code = data["user_code"] + verification_url = data["verification_url"] - org_name = stats.get("organization", "unknown") - total = stats.get("total_documents", 0) await self._send_text( room.room_id, - f"Connected to WildFiles (org: **{org_name}**, {total} documents). " - f"Your documents are now searchable.", + f"To connect WildFiles, visit:\n\n" + f"**{verification_url}**\n\n" + f"and enter code: **{user_code}**\n\n" + f"_This link expires in 10 minutes._", ) - logger.info("User %s connected WildFiles key (org: %s)", sender, org_name) + + # Track pending connect and start polling + self._pending_connects[sender] = device_code + asyncio.create_task(self._poll_device_auth(room.room_id, sender, device_code)) async def _handle_disconnect(self, room, event=None): """Handle !ai disconnect — remove stored WildFiles API key.""" @@ -548,6 +578,46 @@ class Bot: else: await self._send_text(room.room_id, "No WildFiles account connected.") + async def _poll_device_auth(self, room_id: str, sender: str, device_code: str): + """Poll WildFiles for device auth approval (5s interval, 10 min max).""" + poll_url = f"{self.rag.base_url}/api/v1/auth/device/status" + try: + for _ in range(120): # 120 * 5s = 10 min + await asyncio.sleep(5) + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(poll_url, params={"device_code": device_code}) + resp.raise_for_status() + data = resp.json() + except Exception: + logger.debug("Device auth poll failed, retrying", exc_info=True) + continue + + if data["status"] == "approved": + api_key = data["api_key"] + org_slug = data.get("organization", "unknown") + self.user_keys[sender] = api_key + self._save_user_keys() + await self._send_text( + room_id, + f"Connected to WildFiles (org: **{org_slug}**). Your documents are now searchable.", + ) + logger.info("User %s connected via device auth (org: %s)", sender, org_slug) + return + elif data["status"] == "expired": + await self._send_text(room_id, "Connection flow expired. Type `!ai connect` to try again.") + return + + # Timeout after 10 minutes + await self._send_text(room_id, "Connection flow timed out. Type `!ai connect` to try again.") + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Device auth polling error") + await self._send_text(room_id, "Connection flow failed. Type `!ai connect` to try again.") + finally: + self._pending_connects.pop(sender, None) + async def _respond_with_ai(self, room, user_message: str, sender: str = None): model = self.room_models.get(room.room_id, DEFAULT_MODEL)