feat: Add per-user WildFiles auth via !ai connect/disconnect
- !ai connect <key>: validates key against WildFiles, stores per-user mapping, redacts message - !ai disconnect: removes stored key - RAG searches use per-user API key when available, fall back to WILDFILES_ORG - Keys stored in /data/user_keys.json (Docker volume) Implements WF-90 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
127
bot.py
127
bot.py
@@ -50,6 +50,7 @@ LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
|
|||||||
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
|
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
|
||||||
WILDFILES_BASE_URL = os.environ.get("WILDFILES_BASE_URL", "")
|
WILDFILES_BASE_URL = os.environ.get("WILDFILES_BASE_URL", "")
|
||||||
WILDFILES_ORG = os.environ.get("WILDFILES_ORG", "")
|
WILDFILES_ORG = os.environ.get("WILDFILES_ORG", "")
|
||||||
|
USER_KEYS_FILE = os.environ.get("USER_KEYS_FILE", "/data/user_keys.json")
|
||||||
|
|
||||||
SYSTEM_PROMPT = """You are a helpful AI assistant in a Matrix chat room.
|
SYSTEM_PROMPT = """You are a helpful AI assistant in a Matrix chat room.
|
||||||
Keep answers concise but thorough. Use markdown formatting when helpful.
|
Keep answers concise but thorough. Use markdown formatting when helpful.
|
||||||
@@ -69,6 +70,8 @@ HELP_TEXT = """**AI Bot Commands**
|
|||||||
- `!ai models` — List available models
|
- `!ai models` — List available models
|
||||||
- `!ai set-model <model>` — Set model for this room
|
- `!ai set-model <model>` — Set model for this room
|
||||||
- `!ai search <query>` — Search documents (WildFiles)
|
- `!ai search <query>` — Search documents (WildFiles)
|
||||||
|
- `!ai connect <api-key>` — Connect your WildFiles account (DM only)
|
||||||
|
- `!ai disconnect` — Disconnect your WildFiles account
|
||||||
- `!ai auto-rename on|off` — Auto-rename room based on conversation topic
|
- `!ai auto-rename on|off` — Auto-rename room based on conversation topic
|
||||||
- **@mention the bot** or start with `!ai` for a regular AI response"""
|
- **@mention the bot** or start with `!ai` for a regular AI response"""
|
||||||
|
|
||||||
@@ -81,14 +84,21 @@ class DocumentRAG:
|
|||||||
self.org = org
|
self.org = org
|
||||||
self.enabled = bool(base_url and org)
|
self.enabled = bool(base_url and org)
|
||||||
|
|
||||||
async def search(self, query: str, top_k: int = 3) -> list[dict]:
|
async def search(self, query: str, top_k: int = 3, api_key: str | None = None) -> list[dict]:
|
||||||
if not self.enabled:
|
if not api_key and not self.enabled:
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
|
headers = {}
|
||||||
|
body = {"query": query, "limit": top_k}
|
||||||
|
if api_key:
|
||||||
|
headers["X-API-Key"] = api_key
|
||||||
|
else:
|
||||||
|
body["organization"] = self.org
|
||||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
f"{self.base_url}/api/v1/rag/search",
|
f"{self.base_url}/api/v1/rag/search",
|
||||||
json={"query": query, "organization": self.org, "limit": top_k},
|
json=body,
|
||||||
|
headers=headers,
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json().get("results", [])
|
return resp.json().get("results", [])
|
||||||
@@ -96,6 +106,24 @@ class DocumentRAG:
|
|||||||
logger.debug("WildFiles search failed", exc_info=True)
|
logger.debug("WildFiles search failed", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def validate_key(self, api_key: str) -> dict | None:
|
||||||
|
"""Validate an API key against WildFiles. Returns stats dict or None."""
|
||||||
|
if not self.base_url:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{self.base_url}/api/v1/rag/stats",
|
||||||
|
headers={"X-API-Key": api_key},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
if data.get("total_documents", 0) >= 0:
|
||||||
|
return data
|
||||||
|
except Exception:
|
||||||
|
logger.debug("WildFiles key validation failed", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
def format_context(self, results: list[dict]) -> str:
|
def format_context(self, results: list[dict]) -> str:
|
||||||
if not results:
|
if not results:
|
||||||
return ""
|
return ""
|
||||||
@@ -146,6 +174,7 @@ class Bot:
|
|||||||
self.active_calls = set() # rooms where we've sent call member event
|
self.active_calls = set() # rooms where we've sent call member event
|
||||||
self.rag = DocumentRAG(WILDFILES_BASE_URL, WILDFILES_ORG)
|
self.rag = DocumentRAG(WILDFILES_BASE_URL, WILDFILES_ORG)
|
||||||
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
|
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
|
||||||
|
self.user_keys: dict[str, str] = self._load_user_keys() # matrix_user_id -> api_key
|
||||||
self.room_models: dict[str, str] = {} # room_id -> model name
|
self.room_models: dict[str, str] = {} # room_id -> model name
|
||||||
self.auto_rename_rooms: set[str] = set() # rooms with auto-rename enabled
|
self.auto_rename_rooms: set[str] = set() # rooms with auto-rename enabled
|
||||||
self.renamed_rooms: dict[str, float] = {} # room_id -> timestamp of last rename
|
self.renamed_rooms: dict[str, float] = {} # room_id -> timestamp of last rename
|
||||||
@@ -153,6 +182,24 @@ class Bot:
|
|||||||
self._sync_token_received = False
|
self._sync_token_received = False
|
||||||
self._verifications: dict[str, dict] = {} # txn_id -> verification state
|
self._verifications: dict[str, dict] = {} # txn_id -> verification state
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_user_keys() -> dict[str, str]:
|
||||||
|
if os.path.exists(USER_KEYS_FILE):
|
||||||
|
try:
|
||||||
|
with open(USER_KEYS_FILE) as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to load user keys file, starting fresh")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _save_user_keys(self):
|
||||||
|
try:
|
||||||
|
os.makedirs(os.path.dirname(USER_KEYS_FILE), exist_ok=True)
|
||||||
|
with open(USER_KEYS_FILE, "w") as f:
|
||||||
|
json.dump(self.user_keys, f)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to save user keys")
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
# Restore existing session or create new one
|
# Restore existing session or create new one
|
||||||
if os.path.exists(CREDS_FILE):
|
if os.path.exists(CREDS_FILE):
|
||||||
@@ -340,7 +387,7 @@ class Bot:
|
|||||||
# Command handling
|
# Command handling
|
||||||
if body.startswith("!ai "):
|
if body.startswith("!ai "):
|
||||||
cmd = body[4:].strip()
|
cmd = body[4:].strip()
|
||||||
await self._handle_command(room, cmd)
|
await self._handle_command(room, cmd, event)
|
||||||
return
|
return
|
||||||
if body == "!ai":
|
if body == "!ai":
|
||||||
await self._send_text(room.room_id, HELP_TEXT)
|
await self._send_text(room.room_id, HELP_TEXT)
|
||||||
@@ -364,14 +411,20 @@ class Bot:
|
|||||||
|
|
||||||
await self.client.room_typing(room.room_id, typing_state=True)
|
await self.client.room_typing(room.room_id, typing_state=True)
|
||||||
try:
|
try:
|
||||||
await self._respond_with_ai(room, body)
|
await self._respond_with_ai(room, body, sender=event.sender)
|
||||||
finally:
|
finally:
|
||||||
await self.client.room_typing(room.room_id, typing_state=False)
|
await self.client.room_typing(room.room_id, typing_state=False)
|
||||||
|
|
||||||
async def _handle_command(self, room, cmd: str):
|
async def _handle_command(self, room, cmd: str, event=None):
|
||||||
if cmd == "help":
|
if cmd == "help":
|
||||||
await self._send_text(room.room_id, HELP_TEXT)
|
await self._send_text(room.room_id, HELP_TEXT)
|
||||||
|
|
||||||
|
elif cmd.startswith("connect "):
|
||||||
|
await self._handle_connect(room, cmd[8:].strip(), event)
|
||||||
|
|
||||||
|
elif cmd == "disconnect":
|
||||||
|
await self._handle_disconnect(room, event)
|
||||||
|
|
||||||
elif cmd == "models":
|
elif cmd == "models":
|
||||||
if not self.llm:
|
if not self.llm:
|
||||||
await self._send_text(room.room_id, "LLM not configured.")
|
await self._send_text(room.room_id, "LLM not configured.")
|
||||||
@@ -427,7 +480,9 @@ class Bot:
|
|||||||
if not query:
|
if not query:
|
||||||
await self._send_text(room.room_id, "Usage: `!ai search <query>`")
|
await self._send_text(room.room_id, "Usage: `!ai search <query>`")
|
||||||
return
|
return
|
||||||
results = await self.rag.search(query, top_k=5)
|
sender = event.sender if event else None
|
||||||
|
user_api_key = self.user_keys.get(sender) if sender else None
|
||||||
|
results = await self.rag.search(query, top_k=5, api_key=user_api_key)
|
||||||
if not results:
|
if not results:
|
||||||
await self._send_text(room.room_id, "No documents found.")
|
await self._send_text(room.room_id, "No documents found.")
|
||||||
return
|
return
|
||||||
@@ -436,15 +491,64 @@ class Bot:
|
|||||||
else:
|
else:
|
||||||
# Treat unknown commands as AI prompts
|
# Treat unknown commands as AI prompts
|
||||||
if self.llm:
|
if self.llm:
|
||||||
|
sender = event.sender if event else None
|
||||||
await self.client.room_typing(room.room_id, typing_state=True)
|
await self.client.room_typing(room.room_id, typing_state=True)
|
||||||
try:
|
try:
|
||||||
await self._respond_with_ai(room, cmd)
|
await self._respond_with_ai(room, cmd, sender=sender)
|
||||||
finally:
|
finally:
|
||||||
await self.client.room_typing(room.room_id, typing_state=False)
|
await self.client.room_typing(room.room_id, typing_state=False)
|
||||||
else:
|
else:
|
||||||
await self._send_text(room.room_id, f"Unknown command: `{cmd}`\n\n{HELP_TEXT}")
|
await self._send_text(room.room_id, f"Unknown command: `{cmd}`\n\n{HELP_TEXT}")
|
||||||
|
|
||||||
async def _respond_with_ai(self, room, user_message: str):
|
async def _handle_connect(self, room, api_key: str, event=None):
|
||||||
|
"""Handle !ai connect <api-key> — validate and store user's WildFiles API key."""
|
||||||
|
sender = event.sender if event else None
|
||||||
|
if not api_key:
|
||||||
|
await self._send_text(room.room_id, "Usage: `!ai connect <wildfiles-api-key>`")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Redact the message containing the API key for security
|
||||||
|
if event:
|
||||||
|
try:
|
||||||
|
await self.client.room_redact(room.room_id, event.event_id, reason="API key redacted for security")
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not redact connect message", exc_info=True)
|
||||||
|
|
||||||
|
if not self.rag.base_url:
|
||||||
|
await self._send_text(room.room_id, "WildFiles is not configured.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Validate the key
|
||||||
|
stats = await self.rag.validate_key(api_key)
|
||||||
|
if stats is None:
|
||||||
|
await self._send_text(room.room_id, "Invalid API key. Please check and try again.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Store the key
|
||||||
|
self.user_keys[sender] = api_key
|
||||||
|
self._save_user_keys()
|
||||||
|
|
||||||
|
org_name = stats.get("organization", "unknown")
|
||||||
|
total = stats.get("total_documents", 0)
|
||||||
|
await self._send_text(
|
||||||
|
room.room_id,
|
||||||
|
f"Connected to WildFiles (org: **{org_name}**, {total} documents). "
|
||||||
|
f"Your documents are now searchable.",
|
||||||
|
)
|
||||||
|
logger.info("User %s connected WildFiles key (org: %s)", sender, org_name)
|
||||||
|
|
||||||
|
async def _handle_disconnect(self, room, event=None):
|
||||||
|
"""Handle !ai disconnect — remove stored WildFiles API key."""
|
||||||
|
sender = event.sender if event else None
|
||||||
|
if sender and sender in self.user_keys:
|
||||||
|
del self.user_keys[sender]
|
||||||
|
self._save_user_keys()
|
||||||
|
await self._send_text(room.room_id, "Disconnected from WildFiles. Using default search.")
|
||||||
|
logger.info("User %s disconnected WildFiles key", sender)
|
||||||
|
else:
|
||||||
|
await self._send_text(room.room_id, "No WildFiles account connected.")
|
||||||
|
|
||||||
|
async def _respond_with_ai(self, room, user_message: str, sender: str = None):
|
||||||
model = self.room_models.get(room.room_id, DEFAULT_MODEL)
|
model = self.room_models.get(room.room_id, DEFAULT_MODEL)
|
||||||
|
|
||||||
# Fetch conversation history FIRST (needed for query rewriting)
|
# Fetch conversation history FIRST (needed for query rewriting)
|
||||||
@@ -465,8 +569,9 @@ class Bot:
|
|||||||
# Rewrite query using conversation context for better RAG search
|
# Rewrite query using conversation context for better RAG search
|
||||||
search_query = await self._rewrite_query(user_message, history, model)
|
search_query = await self._rewrite_query(user_message, history, model)
|
||||||
|
|
||||||
# WildFiles document context
|
# WildFiles document context (use per-user API key if available)
|
||||||
doc_results = await self.rag.search(search_query)
|
user_api_key = self.user_keys.get(sender) if sender else None
|
||||||
|
doc_results = await self.rag.search(search_query, api_key=user_api_key)
|
||||||
doc_context = self.rag.format_context(doc_results)
|
doc_context = self.rag.format_context(doc_results)
|
||||||
if doc_context:
|
if doc_context:
|
||||||
logger.info("RAG found %d docs for: %s (original: %s)", len(doc_results), search_query[:50], user_message[:50])
|
logger.info("RAG found %d docs for: %s (original: %s)", len(doc_results), search_query[:50], user_message[:50])
|
||||||
|
|||||||
Reference in New Issue
Block a user