# Patch for bot.py: document search now resolves a per-user WildFiles org
# through the MatrixHost portal, keeping stored API keys as a legacy fallback.
#
# NOTE(review): this patch was recovered from a copy whose newlines and
# indentation had been collapsed. Blank lines were re-inserted where the hunk
# header line counts require them; Python indentation is inferred from the
# syntax. Verify against bot.py before applying.
#
# Review notes on the patched code (content below is unchanged):
#  * _get_wildfiles_org: the docstring says it falls back to legacy user_keys,
#    but the body only consults the cache and the portal; callers read
#    self.user_keys themselves.
#  * _get_wildfiles_org: None is cached after the except branch as well, so one
#    failed portal request suppresses lookups for that user for as long as the
#    cache lives.
#  * DocumentRAG.search: when only api_key is supplied, the request body carries
#    "organization" set to None (or an empty self.org).
#  * DocumentRAG.get_org_stats: the request carries no API key or other auth
#    header. TODO confirm the stats endpoint accepts unauthenticated calls.
#  * docs disconnect removes only the stored key; a portal-resolved org is still
#    used for search afterwards.
#  * HELP_TEXT entries "set-model ", "search " and "docs connect " end in a
#    space before the closing backtick; presumably an argument placeholder was
#    lost. Left as found.

diff --git a/bot.py b/bot.py
index 65955df..7512217 100644
--- a/bot.py
+++ b/bot.py
@@ -314,9 +314,9 @@ HELP_TEXT = """**AI Bot Commands**
 - `!ai help` — Show this help
 - `!ai models` — List available models
 - `!ai set-model ` — Set model for this room
-- `!ai search ` — Search documents (WildFiles)
-- `!ai wildfiles connect` — Connect your WildFiles account (opens browser approval)
-- `!ai wildfiles disconnect` — Disconnect your WildFiles account
+- `!ai search ` — Search your documents (auto-connected via MatrixHost)
+- `!ai docs connect ` — Connect with a custom document API key (optional)
+- `!ai docs disconnect` — Disconnect custom document API key
 - `!ai auto-rename on|off` — Auto-rename room based on conversation topic
 - `!ai forget` — Delete all memories the bot has about you
 - `!ai memories` — Show what the bot remembers about you
@@ -330,14 +330,17 @@ class DocumentRAG:
     def __init__(self, base_url: str, org: str):
         self.base_url = base_url.rstrip("/")
         self.org = org
-        self.enabled = bool(base_url and org)
+        self.enabled = bool(base_url)
 
-    async def search(self, query: str, top_k: int = 3, api_key: str | None = None) -> list[dict]:
-        if not api_key:
+    async def search(self, query: str, top_k: int = 3, api_key: str | None = None, org_slug: str | None = None) -> list[dict]:
+        org = org_slug or self.org
+        if not org and not api_key:
             return []
         try:
-            headers = {"X-API-Key": api_key}
-            body = {"query": query, "limit": top_k}
+            headers = {}
+            if api_key:
+                headers["X-API-Key"] = api_key
+            body = {"query": query, "limit": top_k, "organization": org}
             async with httpx.AsyncClient(timeout=15.0) as client:
                 resp = await client.post(
                     f"{self.base_url}/api/v1/rag/search",
@@ -368,6 +371,22 @@ class DocumentRAG:
             logger.debug("WildFiles key validation failed", exc_info=True)
             return None
 
+    async def get_org_stats(self, org_slug: str) -> dict | None:
+        """Get stats for an org by slug. Returns stats dict or None."""
+        if not self.base_url:
+            return None
+        try:
+            async with httpx.AsyncClient(timeout=10.0) as client:
+                resp = await client.get(
+                    f"{self.base_url}/api/v1/rag/stats",
+                    params={"organization": org_slug},
+                )
+                resp.raise_for_status()
+                return resp.json()
+        except Exception:
+            logger.debug("WildFiles org stats failed for %s", org_slug, exc_info=True)
+            return None
+
     def format_context(self, results: list[dict]) -> str:
         if not results:
             return ""
@@ -888,7 +907,8 @@ class Bot:
         self.memory = MemoryClient(MEMORY_SERVICE_URL)
         self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY)
         self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
-        self.user_keys: dict[str, str] = self._load_user_keys()  # matrix_user_id -> api_key
+        self.user_keys: dict[str, str] = self._load_user_keys()  # matrix_user_id -> api_key (legacy)
+        self._wildfiles_org_cache: dict[str, str | None] = {}  # matrix_user_id -> org_slug (from portal)
         self.room_models: dict[str, str] = {}  # room_id -> model name
         self.auto_rename_rooms: set[str] = set()  # rooms with auto-rename enabled
         self._recent_images: dict[str, tuple[str, str, float]] = {}  # room_id -> (b64, mime, timestamp)
@@ -917,6 +937,39 @@ class Bot:
         except Exception:
             logger.exception("Failed to save user keys")
 
+    async def _get_wildfiles_org(self, matrix_user_id: str) -> str | None:
+        """Get user's WildFiles org slug via MatrixHost portal API.
+
+        Auto-provisions a WildFiles org if the user has a MatrixHost account.
+        Falls back to legacy user_keys for backward compat.
+        Results are cached per session.
+        """
+        if matrix_user_id in self._wildfiles_org_cache:
+            return self._wildfiles_org_cache[matrix_user_id]
+
+        # Try portal API (auto-provisions org if needed)
+        if self.atlassian.enabled:  # reuses same portal_url + bot_api_key
+            try:
+                async with httpx.AsyncClient(timeout=10.0) as client:
+                    resp = await client.get(
+                        f"{self.atlassian.portal_url}/api/bot/tokens",
+                        params={"matrix_user_id": matrix_user_id, "provider": "wildfiles"},
+                        headers={"Authorization": f"Bearer {self.atlassian.bot_api_key}"},
+                    )
+                    resp.raise_for_status()
+                    data = resp.json()
+                    if data.get("connected"):
+                        org_slug = data["org_slug"]
+                        self._wildfiles_org_cache[matrix_user_id] = org_slug
+                        logger.debug("Resolved WildFiles org %s for %s via portal", org_slug, matrix_user_id)
+                        return org_slug
+            except Exception:
+                logger.debug("Portal WildFiles org lookup failed for %s", matrix_user_id, exc_info=True)
+
+        # No portal result — cache as None to avoid repeated lookups
+        self._wildfiles_org_cache[matrix_user_id] = None
+        return None
+
     async def start(self):
         # Restore existing session or create new one
         if os.path.exists(CREDS_FILE):
@@ -1813,11 +1866,14 @@ class Bot:
         if cmd == "help":
             await self._send_text(room.room_id, HELP_TEXT)
 
-        elif cmd == "wildfiles connect" or cmd.startswith("wildfiles connect "):
-            api_key = cmd[18:].strip() if cmd.startswith("wildfiles connect ") else ""
+        elif cmd == "wildfiles connect" or cmd.startswith("wildfiles connect ") or cmd == "docs connect" or cmd.startswith("docs connect "):
+            if cmd.startswith("docs connect"):
+                api_key = cmd[12:].strip() if cmd.startswith("docs connect ") else ""
+            else:
+                api_key = cmd[18:].strip() if cmd.startswith("wildfiles connect ") else ""
             await self._handle_connect(room, api_key, event)
 
-        elif cmd == "wildfiles disconnect":
+        elif cmd == "wildfiles disconnect" or cmd == "docs disconnect":
             await self._handle_disconnect(room, event)
 
         elif cmd == "models":
@@ -1898,10 +1954,11 @@ class Bot:
             return
         sender = event.sender if event else None
         user_api_key = self.user_keys.get(sender) if sender else None
-        if not user_api_key:
-            await self._send_text(room.room_id, "WildFiles not connected. Use `!ai wildfiles connect` first.")
+        user_org_slug = await self._get_wildfiles_org(sender) if sender else None
+        if not user_api_key and not user_org_slug:
+            await self._send_text(room.room_id, "Documents not available. Manage your documents at [matrixhost.eu/documents](https://matrixhost.eu/documents).")
             return
-        results = await self.rag.search(query, top_k=5, api_key=user_api_key)
+        results = await self.rag.search(query, top_k=5, api_key=user_api_key, org_slug=user_org_slug)
         if not results:
             await self._send_text(room.room_id, "No documents found.")
             return
@@ -1924,7 +1981,7 @@ class Bot:
         sender = event.sender if event else None
 
         if not self.rag.base_url:
-            await self._send_text(room.room_id, "WildFiles is not configured.")
+            await self._send_text(room.room_id, "Document search is not configured.")
             return
 
         # Fallback: direct API key provided
@@ -1947,13 +2004,26 @@ class Bot:
             total = stats.get("total_documents", 0)
             await self._send_text(
                 room.room_id,
-                f"Connected to WildFiles (org: **{org_name}**, {total} documents). "
+                f"Documents connected (org: **{org_name}**, {total} documents). "
                 f"Your documents are now searchable.",
             )
             logger.info("User %s connected WildFiles key (org: %s)", sender, org_name)
             return
 
-        # SSO device authorization flow
+        # Check if user already has auto-provisioned org via MatrixHost portal
+        if sender:
+            org_slug = await self._get_wildfiles_org(sender)
+            if org_slug:
+                stats = await self.rag.get_org_stats(org_slug)
+                total = stats.get("total_documents", 0) if stats else 0
+                await self._send_text(
+                    room.room_id,
+                    f"Documents are already connected via your MatrixHost account (org: **{org_slug}**, {total} documents). "
+                    f"Manage documents at [matrixhost.eu/documents](https://matrixhost.eu/documents).",
+                )
+                return
+
+        # SSO device authorization flow (fallback for non-MatrixHost users)
         if sender and sender in self._pending_connects:
             await self._send_text(room.room_id, "A connect flow is already in progress. Please complete or wait for it to expire.")
             return
@@ -1974,7 +2044,7 @@ class Bot:
 
         await self._send_text(
             room.room_id,
-            f"To connect WildFiles, visit:\n\n"
+            f"To connect documents, visit:\n\n"
             f"**{verification_url}**\n\n"
             f"and enter code: **{user_code}**\n\n"
             f"_This link expires in 10 minutes._",
@@ -1990,10 +2060,10 @@ class Bot:
         if sender and sender in self.user_keys:
             del self.user_keys[sender]
             self._save_user_keys()
-            await self._send_text(room.room_id, "Disconnected from WildFiles. Using default search.")
+            await self._send_text(room.room_id, "Custom document key removed. Using default document search.")
             logger.info("User %s disconnected WildFiles key", sender)
         else:
-            await self._send_text(room.room_id, "No WildFiles account connected.")
+            await self._send_text(room.room_id, "No custom document key connected.")
 
     async def _poll_device_auth(self, room_id: str, sender: str, device_code: str):
         """Poll WildFiles for device auth approval (5s interval, 10 min max)."""
@@ -2017,7 +2087,7 @@ class Bot:
                     self._save_user_keys()
                     await self._send_text(
                         room_id,
-                        f"Connected to WildFiles (org: **{org_slug}**). Your documents are now searchable.",
+                        f"Documents connected (org: **{org_slug}**). Your documents are now searchable.",
                     )
                     logger.info("User %s connected via device auth (org: %s)", sender, org_slug)
                     return
@@ -2173,9 +2243,10 @@ class Bot:
         # Rewrite query using conversation context for better RAG search
         search_query = await self._rewrite_query(user_message, history, model)
 
-        # WildFiles document context (use per-user API key if available)
+        # WildFiles document context (portal org auto-provision, legacy API key fallback)
         user_api_key = self.user_keys.get(sender) if sender else None
-        doc_results = await self.rag.search(search_query, api_key=user_api_key)
+        user_org_slug = await self._get_wildfiles_org(sender) if sender else None
+        doc_results = await self.rag.search(search_query, api_key=user_api_key, org_slug=user_org_slug)
         doc_context = self.rag.format_context(doc_results)
         if doc_context:
             logger.info("RAG found %d docs for: %s (original: %s)", len(doc_results), search_query[:50], user_message[:50])