feat: Add SSO device auth flow for !ai connect (WF-90)
!ai connect (no args) now starts a browser-based device authorization flow instead of requiring a raw API key. Direct key input preserved as fallback. Bot polls WildFiles for approval with 5s interval. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
122
bot.py
122
bot.py
@@ -70,7 +70,7 @@ HELP_TEXT = """**AI Bot Commands**
|
||||
- `!ai models` — List available models
|
||||
- `!ai set-model <model>` — Set model for this room
|
||||
- `!ai search <query>` — Search documents (WildFiles)
|
||||
- `!ai connect <api-key>` — Connect your WildFiles account (DM only)
|
||||
- `!ai connect` — Connect your WildFiles account (opens browser approval)
|
||||
- `!ai disconnect` — Disconnect your WildFiles account
|
||||
- `!ai auto-rename on|off` — Auto-rename room based on conversation topic
|
||||
- **@mention the bot** or start with `!ai` for a regular AI response"""
|
||||
@@ -181,6 +181,7 @@ class Bot:
|
||||
self._loaded_rooms: set[str] = set() # rooms where we've loaded state
|
||||
self._sync_token_received = False
|
||||
self._verifications: dict[str, dict] = {} # txn_id -> verification state
|
||||
self._pending_connects: dict[str, str] = {} # matrix_user_id -> device_code
|
||||
|
||||
@staticmethod
|
||||
def _load_user_keys() -> dict[str, str]:
|
||||
@@ -419,8 +420,9 @@ class Bot:
|
||||
if cmd == "help":
|
||||
await self._send_text(room.room_id, HELP_TEXT)
|
||||
|
||||
elif cmd.startswith("connect "):
|
||||
await self._handle_connect(room, cmd[8:].strip(), event)
|
||||
elif cmd == "connect" or cmd.startswith("connect "):
|
||||
api_key = cmd[8:].strip() if cmd.startswith("connect ") else ""
|
||||
await self._handle_connect(room, api_key, event)
|
||||
|
||||
elif cmd == "disconnect":
|
||||
await self._handle_disconnect(room, event)
|
||||
@@ -501,41 +503,69 @@ class Bot:
|
||||
await self._send_text(room.room_id, f"Unknown command: `{cmd}`\n\n{HELP_TEXT}")
|
||||
|
||||
async def _handle_connect(self, room, api_key: str, event=None):
|
||||
"""Handle !ai connect <api-key> — validate and store user's WildFiles API key."""
|
||||
"""Handle !ai connect — SSO device flow, or !ai connect <key> as fallback."""
|
||||
sender = event.sender if event else None
|
||||
if not api_key:
|
||||
await self._send_text(room.room_id, "Usage: `!ai connect <wildfiles-api-key>`")
|
||||
return
|
||||
|
||||
# Redact the message containing the API key for security
|
||||
if event:
|
||||
try:
|
||||
await self.client.room_redact(room.room_id, event.event_id, reason="API key redacted for security")
|
||||
except Exception:
|
||||
logger.debug("Could not redact connect message", exc_info=True)
|
||||
|
||||
if not self.rag.base_url:
|
||||
await self._send_text(room.room_id, "WildFiles is not configured.")
|
||||
return
|
||||
|
||||
# Validate the key
|
||||
stats = await self.rag.validate_key(api_key)
|
||||
if stats is None:
|
||||
await self._send_text(room.room_id, "Invalid API key. Please check and try again.")
|
||||
# Fallback: direct API key provided
|
||||
if api_key:
|
||||
# Redact the message containing the API key for security
|
||||
if event:
|
||||
try:
|
||||
await self.client.room_redact(room.room_id, event.event_id, reason="API key redacted for security")
|
||||
except Exception:
|
||||
logger.debug("Could not redact connect message", exc_info=True)
|
||||
|
||||
stats = await self.rag.validate_key(api_key)
|
||||
if stats is None:
|
||||
await self._send_text(room.room_id, "Invalid API key. Please check and try again.")
|
||||
return
|
||||
|
||||
self.user_keys[sender] = api_key
|
||||
self._save_user_keys()
|
||||
org_name = stats.get("organization", "unknown")
|
||||
total = stats.get("total_documents", 0)
|
||||
await self._send_text(
|
||||
room.room_id,
|
||||
f"Connected to WildFiles (org: **{org_name}**, {total} documents). "
|
||||
f"Your documents are now searchable.",
|
||||
)
|
||||
logger.info("User %s connected WildFiles key (org: %s)", sender, org_name)
|
||||
return
|
||||
|
||||
# Store the key
|
||||
self.user_keys[sender] = api_key
|
||||
self._save_user_keys()
|
||||
# SSO device authorization flow
|
||||
if sender and sender in self._pending_connects:
|
||||
await self._send_text(room.room_id, "A connect flow is already in progress. Please complete or wait for it to expire.")
|
||||
return
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(f"{self.rag.base_url}/api/v1/auth/device/code")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
logger.exception("Failed to start device auth flow")
|
||||
await self._send_text(room.room_id, "Failed to start connection flow. Please try again later.")
|
||||
return
|
||||
|
||||
device_code = data["device_code"]
|
||||
user_code = data["user_code"]
|
||||
verification_url = data["verification_url"]
|
||||
|
||||
org_name = stats.get("organization", "unknown")
|
||||
total = stats.get("total_documents", 0)
|
||||
await self._send_text(
|
||||
room.room_id,
|
||||
f"Connected to WildFiles (org: **{org_name}**, {total} documents). "
|
||||
f"Your documents are now searchable.",
|
||||
f"To connect WildFiles, visit:\n\n"
|
||||
f"**{verification_url}**\n\n"
|
||||
f"and enter code: **{user_code}**\n\n"
|
||||
f"_This link expires in 10 minutes._",
|
||||
)
|
||||
logger.info("User %s connected WildFiles key (org: %s)", sender, org_name)
|
||||
|
||||
# Track pending connect and start polling
|
||||
self._pending_connects[sender] = device_code
|
||||
asyncio.create_task(self._poll_device_auth(room.room_id, sender, device_code))
|
||||
|
||||
async def _handle_disconnect(self, room, event=None):
|
||||
"""Handle !ai disconnect — remove stored WildFiles API key."""
|
||||
@@ -548,6 +578,46 @@ class Bot:
|
||||
else:
|
||||
await self._send_text(room.room_id, "No WildFiles account connected.")
|
||||
|
||||
async def _poll_device_auth(self, room_id: str, sender: str, device_code: str):
|
||||
"""Poll WildFiles for device auth approval (5s interval, 10 min max)."""
|
||||
poll_url = f"{self.rag.base_url}/api/v1/auth/device/status"
|
||||
try:
|
||||
for _ in range(120): # 120 * 5s = 10 min
|
||||
await asyncio.sleep(5)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(poll_url, params={"device_code": device_code})
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
logger.debug("Device auth poll failed, retrying", exc_info=True)
|
||||
continue
|
||||
|
||||
if data["status"] == "approved":
|
||||
api_key = data["api_key"]
|
||||
org_slug = data.get("organization", "unknown")
|
||||
self.user_keys[sender] = api_key
|
||||
self._save_user_keys()
|
||||
await self._send_text(
|
||||
room_id,
|
||||
f"Connected to WildFiles (org: **{org_slug}**). Your documents are now searchable.",
|
||||
)
|
||||
logger.info("User %s connected via device auth (org: %s)", sender, org_slug)
|
||||
return
|
||||
elif data["status"] == "expired":
|
||||
await self._send_text(room_id, "Connection flow expired. Type `!ai connect` to try again.")
|
||||
return
|
||||
|
||||
# Timeout after 10 minutes
|
||||
await self._send_text(room_id, "Connection flow timed out. Type `!ai connect` to try again.")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Device auth polling error")
|
||||
await self._send_text(room_id, "Connection flow failed. Type `!ai connect` to try again.")
|
||||
finally:
|
||||
self._pending_connects.pop(sender, None)
|
||||
|
||||
async def _respond_with_ai(self, room, user_message: str, sender: str = None):
|
||||
model = self.room_models.get(room.room_id, DEFAULT_MODEL)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user