bot.py: add an optional customer-VM RAG backend to DocumentRAG.

When RAG_ENDPOINT is set, DocumentRAG.search() queries that service and
Bot._has_documents() consults its /health endpoint first; otherwise the
existing central-portal path is used.

Review notes: the line breaks and indentation of this patch were
reconstructed from a flattened copy. Line counts were re-checked against
the @@ headers, but exact whitespace is inferred (4-space indents, two
spaces before inline comments); if a hunk is rejected, retry with
`git apply --ignore-whitespace`. DocumentRAG.health() additionally logs
failures at debug level, matching the other DocumentRAG request helpers;
that adds one line, so the new-side count of the third hunk and the
new-side offsets of the last two hunks are one higher than in the
flattened copy. The blob hashes on the `index` line were not recomputed.

diff --git a/bot.py b/bot.py
index 8b818d8..907a692 100644
--- a/bot.py
+++ b/bot.py
@@ -73,6 +73,8 @@ CONFLUENCE_USER = os.environ.get("CONFLUENCE_USER", "")
 CONFLUENCE_TOKEN = os.environ.get("CONFLUENCE_TOKEN", "")
 PORTAL_URL = os.environ.get("PORTAL_URL", "")
 BOT_API_KEY = os.environ.get("BOT_API_KEY", "")
+RAG_ENDPOINT = os.environ.get("RAG_ENDPOINT", "")  # Customer-VM RAG service (e.g. http://127.0.0.1:8765)
+RAG_AUTH_TOKEN = os.environ.get("RAG_AUTH_TOKEN", "")  # Bearer token for local RAG
 BRAVE_API_KEY = os.environ.get("BRAVE_API_KEY", "")
 
 MAX_TOOL_ITERATIONS = 5
@@ -316,16 +318,49 @@ Manage settings at [matrixhost.eu/settings](https://matrixhost.eu/settings)."""
 
 
 class DocumentRAG:
-    """Search documents via MatrixHost API (replaces WildFiles)."""
+    """Search documents via customer-VM RAG service or central portal fallback."""
 
-    def __init__(self, portal_url: str, bot_api_key: str):
+    def __init__(self, portal_url: str, bot_api_key: str,
+                 rag_endpoint: str = "", rag_auth_token: str = ""):
         self.portal_url = portal_url.rstrip("/")
         self.bot_api_key = bot_api_key
-        self.enabled = bool(portal_url and bot_api_key)
+        self.rag_endpoint = rag_endpoint.rstrip("/") if rag_endpoint else ""
+        self.rag_auth_token = rag_auth_token
+        self.use_local_rag = bool(self.rag_endpoint)
+        self.enabled = bool(self.rag_endpoint) or bool(portal_url and bot_api_key)
 
     async def search(self, query: str, top_k: int = 3, api_key: str | None = None, org_slug: str | None = None, matrix_user_id: str | None = None) -> list[dict]:
         if not self.enabled or not matrix_user_id:
             return []
+
+        # Prefer customer-VM RAG service (encrypted, local)
+        if self.use_local_rag:
+            return await self._search_local(query, top_k)
+
+        # Fallback: central portal API (legacy, unencrypted)
+        return await self._search_portal(query, top_k, matrix_user_id)
+
+    async def _search_local(self, query: str, top_k: int) -> list[dict]:
+        """Search via customer-VM RAG service (localhost)."""
+        try:
+            body = {"query": query, "limit": top_k}
+            headers: dict[str, str] = {"Content-Type": "application/json"}
+            if self.rag_auth_token:
+                headers["Authorization"] = f"Bearer {self.rag_auth_token}"
+            async with httpx.AsyncClient(timeout=15.0) as client:
+                resp = await client.post(
+                    f"{self.rag_endpoint}/rag/search",
+                    json=body,
+                    headers=headers,
+                )
+                resp.raise_for_status()
+                return resp.json().get("results", [])
+        except Exception:
+            logger.debug("Local RAG search failed", exc_info=True)
+            return []
+
+    async def _search_portal(self, query: str, top_k: int, matrix_user_id: str) -> list[dict]:
+        """Search via central portal API (legacy fallback)."""
         try:
             body = {"query": query, "limit": top_k, "matrix_user_id": matrix_user_id}
             async with httpx.AsyncClient(timeout=15.0) as client:
@@ -337,16 +372,24 @@ class DocumentRAG:
                 resp.raise_for_status()
                 return resp.json().get("results", [])
         except Exception:
-            logger.debug("Document search failed", exc_info=True)
+            logger.debug("Portal document search failed", exc_info=True)
             return []
 
-    async def validate_key(self, api_key: str) -> dict | None:
-        """Legacy: no longer used (keys replaced by portal auth)."""
-        return None
-
-    async def get_org_stats(self, org_slug: str) -> dict | None:
-        """Legacy: no longer used."""
-        return None
+    async def health(self) -> dict | None:
+        """Check local RAG service health."""
+        if not self.use_local_rag:
+            return None
+        try:
+            headers: dict[str, str] = {}
+            if self.rag_auth_token:
+                headers["Authorization"] = f"Bearer {self.rag_auth_token}"
+            async with httpx.AsyncClient(timeout=5.0) as client:
+                resp = await client.get(f"{self.rag_endpoint}/health", headers=headers)
+                resp.raise_for_status()
+                return resp.json()
+        except Exception:
+            logger.debug("Local RAG health check failed", exc_info=True)
+            return None
 
     def format_context(self, results: list[dict]) -> str:
         if not results:
@@ -896,7 +939,8 @@ class Bot:
         self.voice_sessions: dict[str, VoiceSession] = {}
         self.active_calls = set()  # rooms where we've sent call member event
         self.active_callers: dict[str, set[str]] = {}  # room_id → set of caller user IDs
-        self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY)
+        self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY,
+                               rag_endpoint=RAG_ENDPOINT, rag_auth_token=RAG_AUTH_TOKEN)
         self.memory = MemoryClient(MEMORY_SERVICE_URL)
         self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY)
         self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
@@ -911,13 +955,21 @@ class Bot:
         self._room_document_context: dict[str, list[dict]] = {}  # room_id -> [{type, filename, text, timestamp}, ...]
 
     async def _has_documents(self, matrix_user_id: str) -> bool:
-        """Check if user has documents via MatrixHost portal API.
+        """Check if user has documents via local RAG or MatrixHost portal API.
 
         Results are cached per session.
         """
         if matrix_user_id in self._documents_cache:
             return self._documents_cache[matrix_user_id] is not None
 
+        # Check local RAG service first (customer-VM encrypted RAG)
+        if self.rag.use_local_rag:
+            health = await self.rag.health()
+            if health and health.get("document_count", 0) > 0:
+                self._documents_cache[matrix_user_id] = "connected"
+                return True
+
+        # Fallback: check via central portal
         if self.atlassian.enabled:
             try:
                 async with httpx.AsyncClient(timeout=10.0) as client: