feat: Add SSO device auth flow for !ai connect (WF-90)
!ai connect (no args) now starts a browser-based device authorization flow instead of requiring a raw API key. Direct key input preserved as fallback. Bot polls WildFiles for approval with 5s interval. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
122
bot.py
122
bot.py
@@ -70,7 +70,7 @@ HELP_TEXT = """**AI Bot Commands**
|
|||||||
- `!ai models` — List available models
|
- `!ai models` — List available models
|
||||||
- `!ai set-model <model>` — Set model for this room
|
- `!ai set-model <model>` — Set model for this room
|
||||||
- `!ai search <query>` — Search documents (WildFiles)
|
- `!ai search <query>` — Search documents (WildFiles)
|
||||||
- `!ai connect <api-key>` — Connect your WildFiles account (DM only)
|
- `!ai connect` — Connect your WildFiles account (opens browser approval)
|
||||||
- `!ai disconnect` — Disconnect your WildFiles account
|
- `!ai disconnect` — Disconnect your WildFiles account
|
||||||
- `!ai auto-rename on|off` — Auto-rename room based on conversation topic
|
- `!ai auto-rename on|off` — Auto-rename room based on conversation topic
|
||||||
- **@mention the bot** or start with `!ai` for a regular AI response"""
|
- **@mention the bot** or start with `!ai` for a regular AI response"""
|
||||||
@@ -181,6 +181,7 @@ class Bot:
|
|||||||
self._loaded_rooms: set[str] = set() # rooms where we've loaded state
|
self._loaded_rooms: set[str] = set() # rooms where we've loaded state
|
||||||
self._sync_token_received = False
|
self._sync_token_received = False
|
||||||
self._verifications: dict[str, dict] = {} # txn_id -> verification state
|
self._verifications: dict[str, dict] = {} # txn_id -> verification state
|
||||||
|
self._pending_connects: dict[str, str] = {} # matrix_user_id -> device_code
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_user_keys() -> dict[str, str]:
|
def _load_user_keys() -> dict[str, str]:
|
||||||
@@ -419,8 +420,9 @@ class Bot:
|
|||||||
if cmd == "help":
|
if cmd == "help":
|
||||||
await self._send_text(room.room_id, HELP_TEXT)
|
await self._send_text(room.room_id, HELP_TEXT)
|
||||||
|
|
||||||
elif cmd.startswith("connect "):
|
elif cmd == "connect" or cmd.startswith("connect "):
|
||||||
await self._handle_connect(room, cmd[8:].strip(), event)
|
api_key = cmd[8:].strip() if cmd.startswith("connect ") else ""
|
||||||
|
await self._handle_connect(room, api_key, event)
|
||||||
|
|
||||||
elif cmd == "disconnect":
|
elif cmd == "disconnect":
|
||||||
await self._handle_disconnect(room, event)
|
await self._handle_disconnect(room, event)
|
||||||
@@ -501,41 +503,69 @@ class Bot:
|
|||||||
await self._send_text(room.room_id, f"Unknown command: `{cmd}`\n\n{HELP_TEXT}")
|
await self._send_text(room.room_id, f"Unknown command: `{cmd}`\n\n{HELP_TEXT}")
|
||||||
|
|
||||||
async def _handle_connect(self, room, api_key: str, event=None):
|
async def _handle_connect(self, room, api_key: str, event=None):
|
||||||
"""Handle !ai connect <api-key> — validate and store user's WildFiles API key."""
|
"""Handle !ai connect — SSO device flow, or !ai connect <key> as fallback."""
|
||||||
sender = event.sender if event else None
|
sender = event.sender if event else None
|
||||||
if not api_key:
|
|
||||||
await self._send_text(room.room_id, "Usage: `!ai connect <wildfiles-api-key>`")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Redact the message containing the API key for security
|
|
||||||
if event:
|
|
||||||
try:
|
|
||||||
await self.client.room_redact(room.room_id, event.event_id, reason="API key redacted for security")
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not redact connect message", exc_info=True)
|
|
||||||
|
|
||||||
if not self.rag.base_url:
|
if not self.rag.base_url:
|
||||||
await self._send_text(room.room_id, "WildFiles is not configured.")
|
await self._send_text(room.room_id, "WildFiles is not configured.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Validate the key
|
# Fallback: direct API key provided
|
||||||
stats = await self.rag.validate_key(api_key)
|
if api_key:
|
||||||
if stats is None:
|
# Redact the message containing the API key for security
|
||||||
await self._send_text(room.room_id, "Invalid API key. Please check and try again.")
|
if event:
|
||||||
|
try:
|
||||||
|
await self.client.room_redact(room.room_id, event.event_id, reason="API key redacted for security")
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not redact connect message", exc_info=True)
|
||||||
|
|
||||||
|
stats = await self.rag.validate_key(api_key)
|
||||||
|
if stats is None:
|
||||||
|
await self._send_text(room.room_id, "Invalid API key. Please check and try again.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.user_keys[sender] = api_key
|
||||||
|
self._save_user_keys()
|
||||||
|
org_name = stats.get("organization", "unknown")
|
||||||
|
total = stats.get("total_documents", 0)
|
||||||
|
await self._send_text(
|
||||||
|
room.room_id,
|
||||||
|
f"Connected to WildFiles (org: **{org_name}**, {total} documents). "
|
||||||
|
f"Your documents are now searchable.",
|
||||||
|
)
|
||||||
|
logger.info("User %s connected WildFiles key (org: %s)", sender, org_name)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Store the key
|
# SSO device authorization flow
|
||||||
self.user_keys[sender] = api_key
|
if sender and sender in self._pending_connects:
|
||||||
self._save_user_keys()
|
await self._send_text(room.room_id, "A connect flow is already in progress. Please complete or wait for it to expire.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.post(f"{self.rag.base_url}/api/v1/auth/device/code")
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to start device auth flow")
|
||||||
|
await self._send_text(room.room_id, "Failed to start connection flow. Please try again later.")
|
||||||
|
return
|
||||||
|
|
||||||
|
device_code = data["device_code"]
|
||||||
|
user_code = data["user_code"]
|
||||||
|
verification_url = data["verification_url"]
|
||||||
|
|
||||||
org_name = stats.get("organization", "unknown")
|
|
||||||
total = stats.get("total_documents", 0)
|
|
||||||
await self._send_text(
|
await self._send_text(
|
||||||
room.room_id,
|
room.room_id,
|
||||||
f"Connected to WildFiles (org: **{org_name}**, {total} documents). "
|
f"To connect WildFiles, visit:\n\n"
|
||||||
f"Your documents are now searchable.",
|
f"**{verification_url}**\n\n"
|
||||||
|
f"and enter code: **{user_code}**\n\n"
|
||||||
|
f"_This link expires in 10 minutes._",
|
||||||
)
|
)
|
||||||
logger.info("User %s connected WildFiles key (org: %s)", sender, org_name)
|
|
||||||
|
# Track pending connect and start polling
|
||||||
|
self._pending_connects[sender] = device_code
|
||||||
|
asyncio.create_task(self._poll_device_auth(room.room_id, sender, device_code))
|
||||||
|
|
||||||
async def _handle_disconnect(self, room, event=None):
|
async def _handle_disconnect(self, room, event=None):
|
||||||
"""Handle !ai disconnect — remove stored WildFiles API key."""
|
"""Handle !ai disconnect — remove stored WildFiles API key."""
|
||||||
@@ -548,6 +578,46 @@ class Bot:
|
|||||||
else:
|
else:
|
||||||
await self._send_text(room.room_id, "No WildFiles account connected.")
|
await self._send_text(room.room_id, "No WildFiles account connected.")
|
||||||
|
|
||||||
|
async def _poll_device_auth(self, room_id: str, sender: str, device_code: str):
|
||||||
|
"""Poll WildFiles for device auth approval (5s interval, 10 min max)."""
|
||||||
|
poll_url = f"{self.rag.base_url}/api/v1/auth/device/status"
|
||||||
|
try:
|
||||||
|
for _ in range(120): # 120 * 5s = 10 min
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.get(poll_url, params={"device_code": device_code})
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Device auth poll failed, retrying", exc_info=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if data["status"] == "approved":
|
||||||
|
api_key = data["api_key"]
|
||||||
|
org_slug = data.get("organization", "unknown")
|
||||||
|
self.user_keys[sender] = api_key
|
||||||
|
self._save_user_keys()
|
||||||
|
await self._send_text(
|
||||||
|
room_id,
|
||||||
|
f"Connected to WildFiles (org: **{org_slug}**). Your documents are now searchable.",
|
||||||
|
)
|
||||||
|
logger.info("User %s connected via device auth (org: %s)", sender, org_slug)
|
||||||
|
return
|
||||||
|
elif data["status"] == "expired":
|
||||||
|
await self._send_text(room_id, "Connection flow expired. Type `!ai connect` to try again.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Timeout after 10 minutes
|
||||||
|
await self._send_text(room_id, "Connection flow timed out. Type `!ai connect` to try again.")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Device auth polling error")
|
||||||
|
await self._send_text(room_id, "Connection flow failed. Type `!ai connect` to try again.")
|
||||||
|
finally:
|
||||||
|
self._pending_connects.pop(sender, None)
|
||||||
|
|
||||||
async def _respond_with_ai(self, room, user_message: str, sender: str = None):
|
async def _respond_with_ai(self, room, user_message: str, sender: str = None):
|
||||||
model = self.room_models.get(room.room_id, DEFAULT_MODEL)
|
model = self.room_models.get(room.room_id, DEFAULT_MODEL)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user