Files
matrix-ai-agent/bot.py
Christian Gick 4b4a150fbf fix(e2ee): extend key rotation wait to 10s, debug late key events
EC rotates encryption key when bot joins LiveKit room. The rotated
key arrives via Matrix sync 3-5s later. Previous 2s wait was too
short - DEC_FAILED before new key arrived.

Extended wait to 10s. Added logging to bot.py to trace why late
key events were not being processed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-22 21:54:27 +02:00

1763 lines
76 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import json
import asyncio
import base64
import hashlib
import io
import logging
import re
import time
import uuid
import fitz # pymupdf
import httpx
from openai import AsyncOpenAI
from olm import sas as olm_sas
from nio import (
AsyncClient,
AsyncClientConfig,
LoginResponse,
InviteMemberEvent,
MegolmEvent,
RoomEncryptedImage,
RoomMessageFile,
RoomMessageImage,
RoomMessageText,
RoomMessageUnknown,
SyncResponse,
UnknownEvent,
KeyVerificationStart,
KeyVerificationCancel,
KeyVerificationKey,
KeyVerificationMac,
ToDeviceError,
)
from nio.crypto.attachments import decrypt_attachment
from livekit import api
from voice import VoiceSession
# Device ID the bot announces in its own call-membership state events.
BOT_DEVICE_ID = "AIBOT"
# Matrix event types the bot reads or writes.
CALL_MEMBER_TYPE = "org.matrix.msc3401.call.member"
ENCRYPTION_KEYS_TYPE = "io.element.call.encryption_keys"
# Room state event types holding per-room settings (model name, auto-rename flag).
MODEL_STATE_TYPE = "ai.agiliton.model"
RENAME_STATE_TYPE = "ai.agiliton.auto_rename"
logger = logging.getLogger("matrix-ai-bot")
# Required environment variables: a missing one raises KeyError at import time.
HOMESERVER = os.environ["MATRIX_HOMESERVER"]
BOT_USER = os.environ["MATRIX_BOT_USER"]
BOT_PASS = os.environ["MATRIX_BOT_PASSWORD"]
LK_URL = os.environ["LIVEKIT_URL"]
LK_KEY = os.environ["LIVEKIT_API_KEY"]
LK_SECRET = os.environ["LIVEKIT_API_SECRET"]
# Optional environment variables with defaults.
AGENT_NAME = os.environ.get("AGENT_NAME", "matrix-ai")
STORE_PATH = os.environ.get("CRYPTO_STORE_PATH", "/data/crypto_store")
# Saved login (user_id, device_id, access_token) lives inside the crypto store directory.
CREDS_FILE = os.path.join(STORE_PATH, "credentials.json")
# An empty LITELLM_BASE_URL leaves the bot without an LLM client.
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
# Org-scoped WildFiles search is enabled only when both values are non-empty.
WILDFILES_BASE_URL = os.environ.get("WILDFILES_BASE_URL", "")
WILDFILES_ORG = os.environ.get("WILDFILES_ORG", "")
# JSON file mapping Matrix user IDs to API keys.
USER_KEYS_FILE = os.environ.get("USER_KEYS_FILE", "/data/user_keys.json")
MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
# System prompt text: the assistant's behavior rules (reply language, document
# context, images, PDFs, image generation, user memories, translation).
SYSTEM_PROMPT = """You are a helpful AI assistant in a Matrix chat room.
Keep answers concise but thorough. Use markdown formatting when helpful.
Always respond in the same language the user writes in. If you have memories about the user's preferred language, use that language consistently.
IMPORTANT RULES — FOLLOW THESE STRICTLY:
- When document context is provided below, use it to answer. Always include any links.
- NEVER tell the user to run commands or type anything special. No commands exist.
- NEVER mention "!ai", "!ai search", "!ai read", or any slash/bang commands.
- NEVER say you cannot access files, documents, or links.
- NEVER ask the user where documents are stored, how they were uploaded, or under what filename.
- NEVER suggest contacting an administrator, using a web interface, or checking another system.
- NEVER ask follow-up questions about document storage or file locations.
- If no relevant documents were found, simply say you don't have information on that topic and ask if you can help with something else. Do NOT speculate about why or suggest the user look elsewhere.
- You can see and analyze images that users send. Describe what you see when asked about an image.
- You can read and analyze PDF documents that users send. Summarize content and answer questions about them.
- You can generate images when asked — use the generate_image tool for any image creation, drawing, or illustration requests.
- If user memories are provided, use them to personalize responses. Address users by name if known.
- When asked to translate, provide ONLY the translation with no explanation."""
# Tool schema (function-calling format) exposing a single `generate_image`
# function that takes one required string argument, `prompt`.
IMAGE_GEN_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "generate_image",
            "description": (
                "Generate an image from a text description. "
                "Use when the user asks to create, draw, generate, design, "
                "or make an image/picture/photo/illustration."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "Detailed image generation prompt",
                    },
                },
                "required": ["prompt"],
            },
        },
    },
]
# Markdown help text; sent to the room when a user posts a bare `!ai`.
HELP_TEXT = """**AI Bot Commands**
- `!ai help` — Show this help
- `!ai models` — List available models
- `!ai set-model <model>` — Set model for this room
- `!ai search <query>` — Search documents (WildFiles)
- `!ai wildfiles connect` — Connect your WildFiles account (opens browser approval)
- `!ai wildfiles disconnect` — Disconnect your WildFiles account
- `!ai auto-rename on|off` — Auto-rename room based on conversation topic
- `!ai forget` — Delete all memories the bot has about you
- `!ai memories` — Show what the bot remembers about you
- **Translate**: Forward a message to this DM — bot detects language and offers translation
- **@mention the bot** or start with `!ai` for a regular AI response"""
class DocumentRAG:
"""Search WildFiles for relevant documents."""
def __init__(self, base_url: str, org: str):
self.base_url = base_url.rstrip("/")
self.org = org
self.enabled = bool(base_url and org)
async def search(self, query: str, top_k: int = 3, api_key: str | None = None) -> list[dict]:
if not api_key and not self.enabled:
return []
try:
headers = {}
body = {"query": query, "limit": top_k}
if api_key:
headers["X-API-Key"] = api_key
else:
body["organization"] = self.org
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.post(
f"{self.base_url}/api/v1/rag/search",
json=body,
headers=headers,
)
resp.raise_for_status()
return resp.json().get("results", [])
except Exception:
logger.debug("WildFiles search failed", exc_info=True)
return []
async def validate_key(self, api_key: str) -> dict | None:
"""Validate an API key against WildFiles. Returns stats dict or None."""
if not self.base_url:
return None
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(
f"{self.base_url}/api/v1/rag/stats",
headers={"X-API-Key": api_key},
)
resp.raise_for_status()
data = resp.json()
if data.get("total_documents", 0) >= 0:
return data
except Exception:
logger.debug("WildFiles key validation failed", exc_info=True)
return None
def format_context(self, results: list[dict]) -> str:
if not results:
return ""
parts = ["The following documents were found in our document archive:\n"]
for i, r in enumerate(results, 1):
title = r.get("title", r.get("filename", "Untitled"))
link = r.get("source_url") or r.get("metadata", {}).get("source_url", "")
category = r.get("category", "")
date = r.get("detected_date", "")
content = r.get("content", "")
summary = r.get("metadata", {}).get("summary", "")
parts.append(f"--- Document {i}: {title} ---")
if category:
parts.append(f"Category: {category}")
if date:
parts.append(f"Date: {date}")
if link:
parts.append(f"Link: {link}")
if summary:
parts.append(f"Summary: {summary}")
if content:
parts.append(f"Content:\n{content}")
parts.append("") # blank line between docs
parts.append("Use the document content above to answer the user's question. "
"When referencing documents, use markdown links: [Document Title](url). "
"Never show raw URLs.")
return "\n".join(parts)
class MemoryClient:
    """Async HTTP client for the memory-service.

    Every method is best-effort: when the client is disabled (empty base
    URL) or the request fails, it logs a warning where applicable and
    returns an empty/zero value instead of raising.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url.rstrip("/")
        self.enabled = bool(base_url)

    async def store(self, user_id: str, fact: str, source_room: str = ""):
        """Store one fact about ``user_id``; failures are logged, not raised."""
        if not self.enabled:
            return
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(
                    f"{self.base_url}/memories/store",
                    json={"user_id": user_id, "fact": fact, "source_room": source_room},
                )
                # Surface 4xx/5xx as a logged failure, like the other methods do.
                resp.raise_for_status()
        except Exception:
            logger.warning("Memory store failed", exc_info=True)

    async def query(self, user_id: str, query: str, top_k: int = 10) -> list[dict]:
        """Return the ``results`` list for a memory query, or ``[]`` on failure."""
        if not self.enabled:
            return []
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(
                    f"{self.base_url}/memories/query",
                    json={"user_id": user_id, "query": query, "top_k": top_k},
                )
                resp.raise_for_status()
                return resp.json().get("results", [])
        except Exception:
            logger.warning("Memory query failed", exc_info=True)
            return []

    async def delete_user(self, user_id: str) -> int:
        """Delete the user's memories; return the reported ``deleted`` count (0 on failure)."""
        if not self.enabled:
            return 0
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.delete(f"{self.base_url}/memories/{user_id}")
                resp.raise_for_status()
                return resp.json().get("deleted", 0)
        except Exception:
            logger.warning("Memory delete failed", exc_info=True)
            return 0

    async def list_all(self, user_id: str) -> list[dict]:
        """Return the ``memories`` list stored for the user, or ``[]`` on failure."""
        if not self.enabled:
            return []
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(f"{self.base_url}/memories/{user_id}")
                resp.raise_for_status()
                return resp.json().get("memories", [])
        except Exception:
            logger.warning("Memory list failed", exc_info=True)
            return []
class Bot:
def __init__(self):
    """Create the Matrix client and initialise all in-memory bot state."""
    self.client = AsyncClient(
        HOMESERVER,
        BOT_USER,
        store_path=STORE_PATH,
        config=AsyncClientConfig(
            max_limit_exceeded=0,
            max_timeouts=0,
            store_sync_tokens=True,
            encryption_enabled=True,
        ),
    )

    # --- Voice / call state ---
    self.lkapi = None  # assigned in start()
    self.voice_sessions: dict[str, VoiceSession] = {}
    self.active_calls = set()  # rooms where we've sent call member event
    self.active_callers: dict[str, set[str]] = {}  # room_id → set of caller user IDs

    # --- Backing services ---
    self.rag = DocumentRAG(WILDFILES_BASE_URL, WILDFILES_ORG)
    self.memory = MemoryClient(MEMORY_SERVICE_URL)
    if LITELLM_URL:
        self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY)
    else:
        self.llm = None
    self.user_keys: dict[str, str] = self._load_user_keys()  # matrix_user_id -> api_key

    # --- Per-room settings and caches ---
    self.room_models: dict[str, str] = {}  # room_id -> model name
    self.auto_rename_rooms: set[str] = set()  # rooms with auto-rename enabled
    self._recent_images: dict[str, tuple[str, str, float]] = {}  # room_id -> (b64, mime, timestamp)
    self.renamed_rooms: dict[str, float] = {}  # room_id -> timestamp of last rename
    self._loaded_rooms: set[str] = set()  # rooms where we've loaded state

    # --- Sync / verification / pending user workflows ---
    self._sync_token_received = False
    self._verifications: dict[str, dict] = {}  # txn_id -> verification state
    self._pending_connects: dict[str, str] = {}  # matrix_user_id -> device_code
    self._pending_translate: dict[str, dict] = {}  # sender -> {text, detected_lang, room_id}
    self._pending_reply: dict[str, dict] = {}  # sender -> {target_lang}
@staticmethod
def _load_user_keys() -> dict[str, str]:
    """Load the Matrix-user-ID -> API-key map from ``USER_KEYS_FILE``.

    Returns an empty dict when the file is missing, unreadable, not valid
    JSON, or does not contain a JSON object.
    """
    try:
        with open(USER_KEYS_FILE, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # No keys saved yet — not an error.
        return {}
    except Exception:
        logger.warning("Failed to load user keys file, starting fresh", exc_info=True)
        return {}
    if not isinstance(data, dict):
        # A list/string/number here would break later dict-style lookups.
        logger.warning("User keys file does not contain a JSON object, starting fresh")
        return {}
    return data
def _save_user_keys(self):
    """Write ``self.user_keys`` to ``USER_KEYS_FILE`` as JSON.

    Best-effort: any failure is logged with a traceback and swallowed.
    """
    try:
        parent_dir = os.path.dirname(USER_KEYS_FILE)
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(USER_KEYS_FILE, "w", encoding="utf-8") as f:
            json.dump(self.user_keys, f)
    except Exception:
        logger.exception("Failed to save user keys")
async def start(self):
    """Log in (or restore a saved session), register callbacks, and sync forever."""
    if not os.path.exists(CREDS_FILE):
        # First run: password login, then persist credentials for next restart
        login = await self.client.login(BOT_PASS, device_name="ai-voice-bot")
        if not isinstance(login, LoginResponse):
            logger.error("Login failed: %s", login)
            return
        saved = {
            "user_id": login.user_id,
            "device_id": login.device_id,
            "access_token": login.access_token,
        }
        with open(CREDS_FILE, "w") as f:
            json.dump(saved, f)
        logger.info("Logged in as %s (device %s) — credentials saved", login.user_id, login.device_id)
    else:
        # Restore existing session
        with open(CREDS_FILE) as f:
            creds = json.load(f)
        self.client.restore_login(
            user_id=creds["user_id"],
            device_id=creds["device_id"],
            access_token=creds["access_token"],
        )
        self.client.load_store()
        logger.info("Restored session as %s (device %s)", creds["user_id"], creds["device_id"])

    if self.client.should_upload_keys:
        await self.client.keys_upload()

    self.lkapi = api.LiveKitAPI(LK_URL, LK_KEY, LK_SECRET)

    # Room event callbacks, registered in this order.
    room_event_handlers = (
        (self.on_invite, InviteMemberEvent),
        (self.on_megolm, MegolmEvent),
        (self.on_unknown, UnknownEvent),
        (self.on_text_message, RoomMessageText),
        (self.on_image_message, RoomMessageImage),
        (self.on_encrypted_image_message, RoomEncryptedImage),
        (self.on_file_message, RoomMessageFile),
        (self.on_room_unknown, RoomMessageUnknown),
    )
    for handler, event_cls in room_event_handlers:
        self.client.add_event_callback(handler, event_cls)

    self.client.add_response_callback(self.on_sync, SyncResponse)

    # All to-device verification steps share one handler.
    for verification_cls in (
        KeyVerificationStart,
        KeyVerificationKey,
        KeyVerificationMac,
        KeyVerificationCancel,
    ):
        self.client.add_to_device_callback(self.on_key_verification, verification_cls)

    await self.client.sync_forever(timeout=30000, full_state=True)
async def on_invite(self, room, event: InviteMemberEvent):
    """Join the room when the invite is addressed to the bot itself."""
    invited_self = event.state_key == BOT_USER
    if invited_self:
        logger.info("Invited to %s, joining room", room.room_id)
        await self.client.join(room.room_id)
async def on_sync(self, response: SyncResponse):
    """Mark the first sync as done and auto-trust every unverified known device."""
    if not self._sync_token_received:
        self._sync_token_received = True
        logger.info("Initial sync complete, text handler active")

    device_store = self.client.device_store
    # Snapshot the user list before iterating.
    known_users = list(device_store.users)
    for user_id in known_users:
        for device in device_store.active_user_devices(user_id):
            if device.verified:
                continue
            self.client.verify_device(device)
            logger.info("Auto-trusted device %s of %s", device.device_id, user_id)
async def on_unknown(self, room, event: UnknownEvent):
    """Handle call member state events and in-room verification.

    Dispatches on ``event.type``:

    - ``m.key.verification.*``: forwarded to ``_route_verification`` unless
      the bot sent it.
    - ``ENCRYPTION_KEYS_TYPE``: every key in the event is base64-decoded and
      handed to the room's ``VoiceSession`` (if one exists).
    - ``CALL_MEMBER_TYPE``: non-empty content means the sender is in a call
      (the bot joins unless it already has); empty content means the sender
      left (the bot leaves once no tracked caller remains).

    All other event types, and the bot's own events, are ignored.
    """
    # Route verification events
    if event.type.startswith("m.key.verification."):
        if event.sender != BOT_USER:
            await self._route_verification(room, event)
        return

    room_id = room.room_id

    # Forward encryption key events to active voice sessions (skip our own)
    if event.type == ENCRYPTION_KEYS_TYPE:
        if event.sender == BOT_USER:
            return  # ignore our own key events
        content = event.source.get("content", {})
        device_id = content.get("device_id", "")
        keys_list = content.get("keys", [])
        logger.info("Got encryption_keys timeline event from %s in %s (device=%s, keys=%d, content_keys=%s)",
                    event.sender, room_id, device_id, len(keys_list), list(content.keys()))
        vs = self.voice_sessions.get(room_id)
        if not vs:
            logger.warning("No voice session for room %s to deliver encryption key", room_id)
            return
        for k in keys_list:
            if "key" in k and "index" in k:
                try:
                    key_b64 = k["key"]
                    # Keys arrive as unpadded URL-safe base64; restore the padding.
                    key_b64 += "=" * (-len(key_b64) % 4)
                    key_bytes = base64.urlsafe_b64decode(key_b64)
                except (TypeError, ValueError):
                    # One malformed key must not abort the callback or block
                    # the remaining keys (binascii.Error is a ValueError).
                    logger.warning("encryption_keys event has undecodable key at index %s from %s",
                                   k.get("index"), event.sender)
                    continue
                vs.on_encryption_key(event.sender, device_id, key_bytes, k["index"])
            else:
                logger.warning("encryption_keys event missing key/index: %s", k)
        if not keys_list:
            logger.warning("encryption_keys event has empty keys list, full content: %s", content)
        return

    if event.type != CALL_MEMBER_TYPE:
        return
    if event.sender == BOT_USER:
        return  # ignore our own events

    content = event.source.get("content", {})

    # Non-empty content means someone started/is in a call
    if content:
        # Track every participant, including those who join after the bot has
        # already joined; otherwise the "others still in call" check in the
        # leave branch can never see them.
        self.active_callers.setdefault(room_id, set()).add(event.sender)
        if room_id in self.active_calls:
            return
        logger.info("Call detected in %s from %s, joining...", room_id, event.sender)
        self.active_calls.add(room_id)

        # Get the foci_preferred from the caller's event
        foci = content.get("foci_preferred", [{
            "type": "livekit",
            "livekit_service_url": f"{HOMESERVER}/livekit-jwt-service",
            "livekit_alias": room_id,
        }])

        # Compute LiveKit room name using same hash as lk-jwt-service
        # SHA256(room_id + "|" + "m.call#ROOM") encoded as unpadded base64
        lk_room_hash = hashlib.sha256((room_id + "|m.call#ROOM").encode()).digest()
        lk_room_name = base64.b64encode(lk_room_hash).decode().rstrip("=")
        logger.info("LiveKit room name: %s (hashed from %s)", lk_room_name, room_id)

        # Send our own call member state event FIRST so Element Call
        # sends encryption_keys in response (before we start VoiceSession)
        call_content = {
            "application": "m.call",
            "call_id": "",
            "scope": "m.room",
            "device_id": BOT_DEVICE_ID,
            "expires": 7200000,
            "focus_active": {
                "type": "livekit",
                "focus_selection": "oldest_membership",
            },
            "foci_preferred": foci,
            "m.call.intent": "audio",
        }
        state_key = f"_{BOT_USER}_{BOT_DEVICE_ID}_m.call"
        try:
            resp = await self.client.room_put_state(
                room_id, CALL_MEMBER_TYPE, call_content, state_key=state_key,
            )
            logger.info("Sent call member event in %s: %s", room_id, resp)
        except Exception:
            logger.exception("Failed to send call member event in %s", room_id)

        # Now create VoiceSession — encryption_keys may arrive via sync
        # while VoiceSession waits for key (up to 10s)
        if room_id not in self.voice_sessions:
            try:
                model = self.room_models.get(room_id, DEFAULT_MODEL)
                caller_device_id = content.get("device_id", "")

                # Generate bot's own E2EE key (16 bytes like Element Call)
                import secrets
                bot_key = secrets.token_bytes(16)

                vs = VoiceSession(
                    nio_client=self.client,
                    room_id=room_id,
                    device_id=BOT_DEVICE_ID,
                    lk_url=LK_URL,
                    model=model,
                    bot_key=bot_key,
                    publish_key_cb=lambda key, rid=room_id: asyncio.ensure_future(
                        self._publish_encryption_key(rid, key)),
                    memory=self.memory,
                    caller_user_id=event.sender,
                )

                # Check timeline for caller's key
                caller_key = await self._get_call_encryption_key(room_id, event.sender, caller_device_id)
                if caller_key:
                    vs.on_encryption_key(event.sender, caller_device_id, caller_key, 0)

                # Store BEFORE start so on_unknown handler can forward keys via sync
                self.voice_sessions[room_id] = vs

                await vs.start()
                logger.info("Voice session started for room %s (e2ee_key=%s)",
                            room_id, "yes" if caller_key else "waiting for sync")
            except Exception:
                logger.exception("Voice session start failed for %s", room_id)
                self.voice_sessions.pop(room_id, None)
        return

    # Empty content = someone left the call
    if room_id not in self.active_calls:
        return

    # Remove this caller from active set
    callers = self.active_callers.get(room_id, set())
    callers.discard(event.sender)
    if callers:
        logger.info("Participant %s left %s but %d other(s) still in call — keeping session",
                    event.sender, room_id, len(callers))
        return

    # No callers left — stop voice session
    logger.info("Last caller %s left %s — stopping session", event.sender, room_id)
    self.active_callers.pop(room_id, None)
    vs = self.voice_sessions.pop(room_id, None)
    if vs:
        try:
            await vs.stop()
            logger.info("Voice session stopped for %s", room_id)
        except Exception:
            logger.exception("Failed to stop voice session for %s", room_id)

    # Leave the call too
    self.active_calls.discard(room_id)
    state_key = f"_{BOT_USER}_{BOT_DEVICE_ID}_m.call"
    try:
        await self.client.room_put_state(
            room_id, CALL_MEMBER_TYPE, {}, state_key=state_key,
        )
        logger.info("Left call in %s", room_id)
    except Exception:
        logger.exception("Failed to leave call in %s", room_id)
async def _load_room_settings(self, room_id: str):
    """Load persisted model and auto-rename settings from room state.

    Runs at most once per room: the room is marked as loaded before the
    state lookups, so a failed lookup is not retried on later messages.
    Lookup failures are logged at debug level and otherwise ignored.
    """
    if room_id in self._loaded_rooms:
        return
    self._loaded_rooms.add(room_id)

    for state_type in (MODEL_STATE_TYPE, RENAME_STATE_TYPE):
        try:
            resp = await self.client.room_get_state_event(room_id, state_type, "")
        except Exception:
            logger.debug("Could not load %s state for %s", state_type, room_id, exc_info=True)
            continue

        # Responses without a dict `content` carry no usable setting.
        content = getattr(resp, "content", None)
        if not isinstance(content, dict):
            continue

        if state_type == MODEL_STATE_TYPE:
            if "model" in content:
                self.room_models[room_id] = content["model"]
        elif content.get("enabled"):
            self.auto_rename_rooms.add(room_id)
# --- User memory helpers ---
@staticmethod
def _format_memories(memories: list[dict]) -> str:
"""Format memory query results as a system prompt section."""
if not memories:
return ""
facts = [m["fact"] for m in memories]
return "You have these memories about this user:\n" + "\n".join(f"- {f}" for f in facts)
async def _extract_and_store_memories(self, user_message: str, ai_reply: str,
existing_facts: list[str], model: str,
sender: str, room_id: str):
"""Use LLM to extract memorable facts, then store each via memory-service."""
if not self.llm:
return
existing_text = "\n".join(f"- {f}" for f in existing_facts) if existing_facts else "(none)"
logger.info("Memory extraction: user_msg=%s... (%d existing facts)", user_message[:80], len(existing_facts))
try:
resp = await self.llm.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": (
"You extract memorable facts about users from conversations. "
"Return a JSON array of strings — each string is a concise fact worth remembering. "
"Include: name, language preference, location, occupation, interests, preferences, "
"family, pets, projects, important dates, or any personal detail shared. "
"Do NOT include: the current question/topic, temporary info, or things the AI said. "
"Do NOT duplicate existing memories (rephrase or skip if already known). "
"Return [] if nothing new is worth remembering."
)},
{"role": "user", "content": (
f"Existing memories:\n{existing_text}\n\n"
f"User message: {user_message[:500]}\n"
f"AI reply: {ai_reply[:500]}\n\n"
"New facts to remember (JSON array of strings):"
)},
],
max_tokens=300,
)
raw = resp.choices[0].message.content.strip()
logger.info("Memory extraction raw response: %s", raw[:200])
if raw.startswith("```"):
raw = re.sub(r"^```\w*\n?", "", raw)
raw = re.sub(r"\n?```$", "", raw)
match = re.search(r"\[.*\]", raw, re.DOTALL)
if match:
raw = match.group(0)
new_facts = json.loads(raw)
if not isinstance(new_facts, list):
logger.warning("Memory extraction returned non-list: %s", type(new_facts))
return
logger.info("Memory extraction found %d new facts", len(new_facts))
for fact in new_facts:
if isinstance(fact, str) and fact.strip():
await self.memory.store(sender, fact.strip(), room_id)
except json.JSONDecodeError:
logger.warning("Memory extraction JSON parse failed, raw: %s", raw[:200])
except Exception:
logger.warning("Memory extraction failed", exc_info=True)
async def _detect_language(self, text: str) -> str:
"""Detect the language of a text using a fast LLM call."""
if not self.llm:
return "Unknown"
try:
resp = await self.llm.chat.completions.create(
model=DEFAULT_MODEL,
messages=[
{"role": "system", "content": "What language is this text? Reply with ONLY the language name in English."},
{"role": "user", "content": text[:500]},
],
max_tokens=10,
)
return resp.choices[0].message.content.strip()
except Exception:
logger.debug("Language detection failed", exc_info=True)
return "Unknown"
async def _translate_text(self, text: str, target_language: str, model: str | None = None) -> str:
"""Translate text to the target language using LLM."""
if not self.llm:
return text
try:
resp = await self.llm.chat.completions.create(
model=model or DEFAULT_MODEL,
messages=[
{"role": "system", "content": f"Translate the following text to {target_language}. Return ONLY the translation."},
{"role": "user", "content": text},
],
max_tokens=1000,
)
return resp.choices[0].message.content.strip()
except Exception:
logger.debug("Translation failed", exc_info=True)
return f"[Translation failed] {text}"
async def _get_preferred_language(self, user_id: str) -> str:
"""Get user's preferred language from memories (last match = most recent)."""
memories = await self.memory.query(user_id, "preferred language", top_k=5)
known_langs = [
"English", "German", "French", "Spanish", "Italian", "Portuguese",
"Dutch", "Russian", "Chinese", "Japanese", "Korean", "Arabic",
"Turkish", "Polish", "Swedish", "Norwegian", "Danish", "Finnish",
"Greek", "Hebrew", "Hindi", "Thai", "Vietnamese", "Indonesian",
"Czech", "Romanian", "Hungarian", "Ukrainian", "Croatian", "Serbian",
]
result = "English"
for m in memories:
fact = m["fact"].lower()
if "language" in fact or "speaks" in fact or "prefers" in fact:
for lang in known_langs:
if lang.lower() in fact:
result = lang
break
return result
async def on_text_message(self, room, event: RoomMessageText):
    """Handle text messages: commands and AI responses.

    Flow, in order:

    1. Ignore the bot's own messages, anything before the first sync, and
       messages older than 30 seconds.
    2. ``!ai <cmd>`` goes to ``_handle_command``; a bare ``!ai`` sends the help.
    3. In group rooms, continue only when the bot is mentioned.
    4. In DMs, run the translation workflow: a pending "reply" is translated,
       a pending menu answer is resolved, and a new message in a language
       other than the user's preferred one triggers the menu.
    5. Otherwise answer with the AI, attaching an image sent in this room
       within the last 60 seconds if there is one.
    """
    if event.sender == BOT_USER:
        return
    if not self._sync_token_received:
        return  # ignore messages from initial sync / backfill

    # Ignore old messages (>30s) to avoid replaying history
    server_ts = event.server_timestamp / 1000
    if time.time() - server_ts > 30:
        return

    await self._load_room_settings(room.room_id)
    body = event.body.strip()

    # Command handling
    if body.startswith("!ai "):
        cmd = body[4:].strip()
        await self._handle_command(room, cmd, event)
        return
    if body == "!ai":
        await self._send_text(room.room_id, HELP_TEXT)
        return

    # In DMs (2 members), respond to all messages; in groups, require @mention
    is_dm = room.member_count == 2
    if not is_dm:
        bot_display = self.client.user_id.split(":")[0].lstrip("@")
        # A match on the bare localpart also covers the "@localpart" form.
        mentioned = BOT_USER in body or bot_display.lower() in body.lower()
        if not mentioned:
            return

    if not self.llm:
        await self._send_text(room.room_id, "LLM not configured (LITELLM_BASE_URL not set).")
        return

    sender = event.sender

    # --- DM translation workflow: handle pending reply composition ---
    if is_dm and sender in self._pending_reply:
        pending = self._pending_reply.pop(sender)
        await self.client.room_typing(room.room_id, typing_state=True)
        try:
            translated = await self._translate_text(body, pending["target_lang"])
            await self._send_text(room.room_id, translated)
        finally:
            await self.client.room_typing(room.room_id, typing_state=False)
        return

    # Set when the user explicitly asked for a normal answer to a message
    # that already went through language detection.
    skip_language_check = False

    # --- DM translation workflow: handle menu response ---
    if is_dm and sender in self._pending_translate:
        pending = self._pending_translate.pop(sender)
        choice = body.strip().lower()

        def picked(digit: str) -> bool:
            # Accept the plain digit and its keycap-emoji forms
            # (with or without the U+FE0F variation selector).
            return choice in (digit, f"{digit}\u20e3", f"{digit}\ufe0f\u20e3")

        if picked("1") or choice.startswith("translate"):
            preferred_lang = await self._get_preferred_language(sender)
            await self.client.room_typing(room.room_id, typing_state=True)
            try:
                translated = await self._translate_text(pending["text"], preferred_lang)
                await self._send_text(room.room_id, translated)
            finally:
                await self.client.room_typing(room.room_id, typing_state=False)
            return
        elif picked("2") or choice.startswith("reply"):
            self._pending_reply[sender] = {"target_lang": pending["detected_lang"]}
            await self._send_text(
                room.room_id,
                f"Type your message — I'll translate it to **{pending['detected_lang']}**.",
            )
            return
        elif picked("3") or choice.startswith(("just respond", "respond normally")):
            # "Just respond normally": answer the ORIGINAL message, not the
            # literal menu choice, and don't show the menu for it again.
            body = pending["text"]
            skip_language_check = True
        # anything else → treat the new message as a fresh request below

    # Check if a recent image was sent in this room (within 60s)
    image_data = None
    cached = self._recent_images.get(room.room_id)
    if cached:
        b64, mime, ts = cached
        if time.time() - ts < 60:
            image_data = (b64, mime)
        # The cached image is single-use; expired entries are dropped too.
        del self._recent_images[room.room_id]

    # --- DM translation workflow: detect foreign language ---
    if is_dm and not skip_language_check and not body.startswith("!ai") and not image_data:
        preferred_lang = await self._get_preferred_language(sender)
        detected_lang = await self._detect_language(body)
        logger.info("Translation check: detected=%s, preferred=%s, len=%d", detected_lang, preferred_lang, len(body))
        if (
            detected_lang != "Unknown"
            and detected_lang.lower() != preferred_lang.lower()
            and len(body) > 10  # skip very short messages
        ):
            self._pending_translate[sender] = {
                "text": body,
                "detected_lang": detected_lang,
                "room_id": room.room_id,
            }
            menu = (
                f"This looks like **{detected_lang}**. What would you like?\n"
                f"1⃣ **Translate to {preferred_lang}**\n"
                f"2⃣ **Help me reply in {detected_lang}** (type your response, I'll translate)\n"
                f"3⃣ **Just respond normally**"
            )
            await self._send_text(room.room_id, menu)
            return

    await self.client.room_typing(room.room_id, typing_state=True)
    try:
        await self._respond_with_ai(room, body, sender=sender, image_data=image_data)
    finally:
        await self.client.room_typing(room.room_id, typing_state=False)
async def on_image_message(self, room, event: RoomMessageImage):
    """Handle image messages: download, encode, and send to AI for analysis.

    Ignores the bot's own events, events before the first sync, and events
    older than 30 seconds. In group rooms the image body (caption) must
    mention the bot. The base64-encoded image is cached per room so a
    follow-up text message can refer to it.
    """
    if event.sender == BOT_USER:
        return
    if not self._sync_token_received:
        return
    server_ts = event.server_timestamp / 1000
    if time.time() - server_ts > 30:
        return

    await self._load_room_settings(room.room_id)
    caption = (event.body or "").strip()

    # In DMs respond to all images; in groups only if the caption mentions the bot
    is_dm = room.member_count == 2
    if not is_dm:
        bot_display = self.client.user_id.split(":")[0].lstrip("@")
        # A match on the bare localpart also covers the "@localpart" form.
        mentioned = BOT_USER in caption or bot_display.lower() in caption.lower()
        if not mentioned:
            return

    if not self.llm:
        await self._send_text(room.room_id, "LLM not configured (LITELLM_BASE_URL not set).")
        return

    # Download image from Matrix homeserver
    mxc_url = event.url
    if not mxc_url:
        return
    try:
        resp = await self.client.download(mxc=mxc_url)
        if not hasattr(resp, "body"):
            logger.warning("Image download failed for %s", mxc_url)
            return
        img_bytes = resp.body
    except Exception:
        logger.exception("Failed to download image %s", mxc_url)
        return

    # Determine MIME type: event attribute first, then the raw event's
    # content.info.mimetype, then fall back to PNG.
    mime_type = getattr(event, "mimetype", None)
    if not mime_type:
        source = getattr(event, "source", None)
        event_content = source.get("content") if isinstance(source, dict) else None
        info = event_content.get("info") if isinstance(event_content, dict) else None
        if isinstance(info, dict):
            mime_type = info.get("mimetype")
    mime_type = mime_type or "image/png"

    b64_data = base64.b64encode(img_bytes).decode("utf-8")

    # Treat filenames (contain dots or are very long) as no caption
    is_filename = not caption or caption == "image" or "." in caption or len(caption) > 100
    text = "What's in this image?" if is_filename else caption

    # Cache image for follow-up text messages
    self._recent_images[room.room_id] = (b64_data, mime_type, time.time())

    await self.client.room_typing(room.room_id, typing_state=True)
    try:
        await self._respond_with_ai(room, text, sender=event.sender, image_data=(b64_data, mime_type))
    finally:
        await self.client.room_typing(room.room_id, typing_state=False)
async def on_encrypted_image_message(self, room, event: RoomEncryptedImage):
    """Handle encrypted image messages: decrypt, encode, and send to AI.

    Applies the same guards as the other message handlers (own events,
    pre-sync events, events older than 30 seconds, and a required mention
    in group rooms), then downloads and decrypts the attachment.
    """
    if event.sender == BOT_USER or not self._sync_token_received:
        return
    age_seconds = time.time() - event.server_timestamp / 1000
    if age_seconds > 30:
        return

    room_id = room.room_id
    await self._load_room_settings(room_id)

    if room.member_count != 2:
        # Group room: only react when the caption mentions the bot.
        body = (event.body or "").strip()
        bot_display = self.client.user_id.split(":")[0].lstrip("@")
        lowered = body.lower()
        mention_forms = (
            BOT_USER in body,
            f"@{bot_display}" in lowered,
            bot_display.lower() in lowered,
        )
        if not any(mention_forms):
            return

    if not self.llm:
        await self._send_text(room_id, "LLM not configured (LITELLM_BASE_URL not set).")
        return

    mxc_url = event.url
    if not mxc_url:
        return

    try:
        download = await self.client.download(mxc=mxc_url)
        if not hasattr(download, "body"):
            logger.warning("Encrypted image download failed for %s", mxc_url)
            return
        # Decrypt the attachment
        img_bytes = decrypt_attachment(download.body, event.key["k"], event.hashes["sha256"], event.iv)
    except Exception:
        logger.exception("Failed to download/decrypt encrypted image %s", mxc_url)
        return

    mime_type = getattr(event, "mimetype", None) or "image/png"
    b64_data = base64.b64encode(img_bytes).decode("utf-8")

    # A body that looks like a filename (or is empty/very long) is not a caption.
    caption = (event.body or "").strip()
    looks_like_filename = (
        not caption
        or caption == "image"
        or "." in caption
        or len(caption) > 100
    )
    if looks_like_filename:
        text = "What's in this image?"
    else:
        text = caption

    # Cache image for follow-up text messages
    self._recent_images[room_id] = (b64_data, mime_type, time.time())

    await self.client.room_typing(room_id, typing_state=True)
    try:
        await self._respond_with_ai(room, text, sender=event.sender, image_data=(b64_data, mime_type))
    finally:
        await self.client.room_typing(room_id, typing_state=False)
async def on_file_message(self, room, event: RoomMessageFile):
    """Download a PDF attachment, extract its text and ask the LLM about it.

    Non-PDF files are ignored. The same sender / sync / age filters as for the
    other message handlers apply, and in rooms whose member count is not
    exactly two the message body must mention the bot.
    """
    if event.sender == BOT_USER or not self._sync_token_received:
        return
    if time.time() - event.server_timestamp / 1000 > 30:
        return

    # A file counts as a PDF by declared MIME type or by filename extension.
    content = (event.source or {}).get("content", {})
    mime_type = content.get("info", {}).get("mimetype", "")
    filename = content.get("body", "file")
    is_pdf = mime_type == "application/pdf" or filename.lower().endswith(".pdf")
    if not is_pdf:
        return

    await self._load_room_settings(room.room_id)

    if room.member_count != 2:
        # Group room: the message body itself has to reference the bot.
        body = (event.body or "").strip()
        localpart = self.client.user_id.split(":")[0].lstrip("@")
        lowered = body.lower()
        addressed = (
            BOT_USER in body
            or f"@{localpart}" in lowered
            or localpart.lower() in lowered
        )
        if not addressed:
            return

    if not self.llm:
        await self._send_text(room.room_id, "LLM not configured (LITELLM_BASE_URL not set).")
        return

    mxc_url = event.url
    if not mxc_url:
        return

    try:
        download = await self.client.download(mxc=mxc_url)
        if not hasattr(download, "body"):
            logger.warning("File download failed for %s", mxc_url)
            return
        pdf_bytes = download.body
    except Exception:
        logger.exception("Failed to download file %s", mxc_url)
        return

    pdf_text = self._extract_pdf_text(pdf_bytes)
    if not pdf_text:
        await self._send_text(room.room_id, "I couldn't extract any text from that PDF.")
        return

    # Cap the prompt size (roughly 50k chars ≈ 12k tokens).
    max_chars = 50000
    if len(pdf_text) > max_chars:
        pdf_text = pdf_text[:max_chars] + "\n\n[... truncated, PDF too long ...]"

    user_message = f'The user sent a PDF file named "{filename}". Here is the extracted text:\n\n{pdf_text}\n\nPlease summarize or answer questions about this document.'

    await self.client.room_typing(room.room_id, typing_state=True)
    try:
        await self._respond_with_ai(room, user_message, sender=event.sender)
    finally:
        await self.client.room_typing(room.room_id, typing_state=False)
@staticmethod
def _extract_pdf_text(pdf_bytes: bytes) -> str:
    """Extract text from PDF bytes using pymupdf.

    Returns the text of every non-empty page, each prefixed with a
    ``--- Page N ---`` header and separated by blank lines. Returns an empty
    string for empty input or when the document cannot be parsed (the error
    is logged).
    """
    if not pdf_bytes:
        return ""
    try:
        # The context manager closes the document even if a page fails to
        # render, which the previous explicit close() call did not.
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            pages = [
                f"--- Page {number} ---\n{text}"
                for number, page in enumerate(doc, start=1)
                if (text := page.get_text().strip())
            ]
        return "\n\n".join(pages)
    except Exception:
        logger.exception("PDF text extraction failed")
        return ""
async def _handle_command(self, room, cmd: str, event=None):
    """Run the bot command ``cmd`` for ``room``.

    Recognised commands are handled directly. Anything else is forwarded to
    the LLM as a prompt when an LLM client is configured, otherwise an
    "unknown command" message with the help text is sent.
    """
    room_id = room.room_id

    if cmd == "help":
        await self._send_text(room_id, HELP_TEXT)
        return

    connect_prefix = "wildfiles connect "
    if cmd == "wildfiles connect" or cmd.startswith(connect_prefix):
        # An optional API key may follow the command.
        api_key = cmd[len(connect_prefix):].strip() if cmd.startswith(connect_prefix) else ""
        await self._handle_connect(room, api_key, event)
        return

    if cmd == "wildfiles disconnect":
        await self._handle_disconnect(room, event)
        return

    if cmd == "models":
        if not self.llm:
            await self._send_text(room_id, "LLM not configured.")
            return
        try:
            listing = await self.llm.models.list()
            names = sorted(entry.id for entry in listing.data)
            active = self.room_models.get(room_id, DEFAULT_MODEL)
            rows = [f"- `{name}` {'← current' if name == active else ''}" for name in names]
            await self._send_text(room_id, "**Available models:**\n" + "\n".join(rows))
        except Exception:
            logger.exception("Failed to list models")
            await self._send_text(room_id, "Failed to fetch model list.")
        return

    if cmd.startswith("set-model "):
        model = cmd[len("set-model "):].strip()
        if not model:
            await self._send_text(room_id, "Usage: `!ai set-model <model-name>`")
            return
        self.room_models[room_id] = model
        # Also write the choice to room state so it can be read back later.
        try:
            await self.client.room_put_state(
                room_id, MODEL_STATE_TYPE, {"model": model}, state_key="",
            )
        except Exception:
            logger.debug("Could not persist model to room state", exc_info=True)
        await self._send_text(room_id, f"Model set to `{model}` for this room.")
        return

    if cmd.startswith("auto-rename "):
        setting = cmd[len("auto-rename "):].strip().lower()
        if setting not in ("on", "off"):
            await self._send_text(room_id, "Usage: `!ai auto-rename on|off`")
            return
        enabled = setting == "on"
        if enabled:
            self.auto_rename_rooms.add(room_id)
        else:
            self.auto_rename_rooms.discard(room_id)
        try:
            await self.client.room_put_state(
                room_id, RENAME_STATE_TYPE, {"enabled": enabled}, state_key="",
            )
        except Exception:
            logger.debug("Could not persist auto-rename to room state", exc_info=True)
        status = "enabled" if enabled else "disabled"
        await self._send_text(room_id, f"Auto-rename **{status}** for this room.")
        return

    if cmd == "forget":
        sender = event.sender if event else None
        if not sender:
            await self._send_text(room_id, "Could not identify user.")
            return
        deleted = await self.memory.delete_user(sender)
        self._pending_translate.pop(sender, None)
        self._pending_reply.pop(sender, None)
        await self._send_text(room_id, f"All my memories about you have been deleted ({deleted} facts removed).")
        return

    if cmd == "memories":
        sender = event.sender if event else None
        if not sender:
            await self._send_text(room_id, "Could not identify user.")
            return
        memories = await self.memory.list_all(sender)
        if memories:
            text = f"**I remember {len(memories)} things about you:**\n"
            text += "\n".join(f"- {m['fact']}" for m in memories)
        else:
            text = "I don't have any memories about you yet."
        await self._send_text(room_id, text)
        return

    if cmd.startswith("search "):
        query = cmd[len("search "):].strip()
        if not query:
            await self._send_text(room_id, "Usage: `!ai search <query>`")
            return
        sender = event.sender if event else None
        user_api_key = self.user_keys.get(sender) if sender else None
        results = await self.rag.search(query, top_k=5, api_key=user_api_key)
        if not results:
            await self._send_text(room_id, "No documents found.")
            return
        await self._send_text(room_id, self.rag.format_context(results))
        return

    # Not a known command: hand the text to the LLM if there is one.
    if not self.llm:
        await self._send_text(room_id, f"Unknown command: `{cmd}`\n\n{HELP_TEXT}")
        return
    sender = event.sender if event else None
    await self.client.room_typing(room_id, typing_state=True)
    try:
        await self._respond_with_ai(room, cmd, sender=sender)
    finally:
        await self.client.room_typing(room_id, typing_state=False)
async def _handle_connect(self, room, api_key: str, event=None):
    """Connect the sender's WildFiles account.

    With ``api_key`` given, the key is validated and stored directly (the
    message carrying it is redacted first). Without it, a device
    authorization flow is started and polled in a background task.

    Nothing is stored when the sender cannot be determined, so no key or
    pending flow is ever filed under ``None``.
    """
    sender = event.sender if event else None
    if not self.rag.base_url:
        await self._send_text(room.room_id, "WildFiles is not configured.")
        return
    if not sender:
        # Without a sender there is no user to attach the key or flow to.
        await self._send_text(room.room_id, "Could not identify user.")
        return

    # Fallback: direct API key provided
    if api_key:
        # Redact the message containing the API key for security
        try:
            await self.client.room_redact(room.room_id, event.event_id, reason="API key redacted for security")
        except Exception:
            logger.debug("Could not redact connect message", exc_info=True)
        stats = await self.rag.validate_key(api_key)
        if stats is None:
            await self._send_text(room.room_id, "Invalid API key. Please check and try again.")
            return
        self.user_keys[sender] = api_key
        self._save_user_keys()
        org_name = stats.get("organization", "unknown")
        total = stats.get("total_documents", 0)
        await self._send_text(
            room.room_id,
            f"Connected to WildFiles (org: **{org_name}**, {total} documents). "
            f"Your documents are now searchable.",
        )
        logger.info("User %s connected WildFiles key (org: %s)", sender, org_name)
        return

    # SSO device authorization flow
    if sender in self._pending_connects:
        await self._send_text(room.room_id, "A connect flow is already in progress. Please complete or wait for it to expire.")
        return
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(f"{self.rag.base_url}/api/v1/auth/device/code")
            resp.raise_for_status()
            data = resp.json()
        # Read the fields inside the try block so a malformed response is
        # reported to the user instead of raising out of the handler.
        device_code = data["device_code"]
        user_code = data["user_code"]
        verification_url = data["verification_url"]
    except Exception:
        logger.exception("Failed to start device auth flow")
        await self._send_text(room.room_id, "Failed to start connection flow. Please try again later.")
        return

    await self._send_text(
        room.room_id,
        f"To connect WildFiles, visit:\n\n"
        f"**{verification_url}**\n\n"
        f"and enter code: **{user_code}**\n\n"
        f"_This link expires in 10 minutes._",
    )

    # Track pending connect and start polling
    self._pending_connects[sender] = device_code
    poll_task = asyncio.create_task(self._poll_device_auth(room.room_id, sender, device_code))
    # The event loop keeps only weak references to tasks; hold a strong one
    # until the task finishes so it cannot be garbage-collected mid-flight.
    if not hasattr(self, "_connect_tasks"):
        self._connect_tasks = set()
    self._connect_tasks.add(poll_task)
    poll_task.add_done_callback(self._connect_tasks.discard)
async def _handle_disconnect(self, room, event=None):
    """Forget the sender's stored WildFiles API key, if there is one."""
    sender = event.sender if event else None
    if not sender or sender not in self.user_keys:
        await self._send_text(room.room_id, "No WildFiles account connected.")
        return

    self.user_keys.pop(sender)
    self._save_user_keys()
    await self._send_text(room.room_id, "Disconnected from WildFiles. Using default search.")
    logger.info("User %s disconnected WildFiles key", sender)
async def _poll_device_auth(self, room_id: str, sender: str, device_code: str):
    """Poll WildFiles for device auth approval (5s interval, 10 min max).

    On approval the returned API key is stored for ``sender`` and a
    confirmation is posted to ``room_id``. Expiry, timeout and unexpected
    errors are reported to the room as well. The sender's entry in
    ``_pending_connects`` is always removed when polling ends.
    """
    poll_url = f"{self.rag.base_url}/api/v1/auth/device/status"
    try:
        # One client for the whole flow: connections are reused instead of
        # building a new client for each of the up to 120 polls.
        async with httpx.AsyncClient(timeout=10.0) as client:
            for _ in range(120):  # 120 * 5s = 10 min
                await asyncio.sleep(5)
                try:
                    resp = await client.get(poll_url, params={"device_code": device_code})
                    resp.raise_for_status()
                    data = resp.json()
                except Exception:
                    logger.debug("Device auth poll failed, retrying", exc_info=True)
                    continue

                # A payload without a usable status is treated like any other
                # failed poll and simply retried.
                status = data.get("status") if isinstance(data, dict) else None
                if status == "approved":
                    api_key = data["api_key"]
                    org_slug = data.get("organization", "unknown")
                    self.user_keys[sender] = api_key
                    self._save_user_keys()
                    await self._send_text(
                        room_id,
                        f"Connected to WildFiles (org: **{org_slug}**). Your documents are now searchable.",
                    )
                    logger.info("User %s connected via device auth (org: %s)", sender, org_slug)
                    return
                if status == "expired":
                    await self._send_text(room_id, "Connection flow expired. Type `!ai connect` to try again.")
                    return

        # Timeout after 10 minutes
        await self._send_text(room_id, "Connection flow timed out. Type `!ai connect` to try again.")
    except asyncio.CancelledError:
        # Cancellation ends the flow quietly; cleanup happens in ``finally``.
        pass
    except Exception:
        logger.exception("Device auth polling error")
        await self._send_text(room_id, "Connection flow failed. Type `!ai connect` to try again.")
    finally:
        self._pending_connects.pop(sender, None)
async def _respond_with_ai(self, room, user_message: str, sender: str | None = None, image_data: tuple | None = None):
    """Produce and send the LLM answer to ``user_message`` in ``room``.

    Steps: fetch recent room history, rewrite the message into a standalone
    search query, collect document and memory context, call the chat model,
    send the reply (and any generated image), then store new memories and
    optionally rename the room.

    Args:
        room: Room object exposing ``room_id``.
        user_message: Text of the current user turn.
        sender: Matrix ID of the user; enables per-user document keys and
            memories when given.
        image_data: Optional ``(base64_string, mime_type)`` pair attached to
            the user turn. Image-generation tools are not offered when set.
    """
    model = self.room_models.get(room.room_id, DEFAULT_MODEL)

    # Fetch conversation history FIRST (needed for query rewriting)
    history = []
    try:
        resp = await self.client.room_messages(
            room.room_id, start=self.client.next_batch or "", limit=10
        )
        if hasattr(resp, "chunk"):
            for evt in reversed(resp.chunk):
                if not hasattr(evt, "body"):
                    continue
                role = "assistant" if evt.sender == BOT_USER else "user"
                history.append({"role": role, "content": evt.body})
    except Exception:
        logger.debug("Could not fetch room history, proceeding without context", exc_info=True)

    # Rewrite query using conversation context for better RAG search
    search_query = await self._rewrite_query(user_message, history, model)

    # WildFiles document context (use per-user API key if available)
    user_api_key = self.user_keys.get(sender) if sender else None
    doc_results = await self.rag.search(search_query, api_key=user_api_key)
    doc_context = self.rag.format_context(doc_results)
    if doc_context:
        logger.info("RAG found %d docs for: %s (original: %s)", len(doc_results), search_query[:50], user_message[:50])
    else:
        logger.info("RAG found 0 docs for: %s (original: %s)", search_query[:50], user_message[:50])

    # Query relevant memories via semantic search
    memories = await self.memory.query(sender, user_message, top_k=10) if sender else []
    memory_context = self._format_memories(memories)

    # Build conversation context
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    if memory_context:
        messages.append({"role": "system", "content": memory_context})
    if doc_context:
        messages.append({"role": "system", "content": doc_context})
    messages.extend(history)

    # Add current user message (multimodal if image provided)
    if image_data:
        b64_str, mime_type = image_data
        user_content = [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64_str}"}},
        ]
        messages.append({"role": "user", "content": user_content})
    else:
        messages.append({"role": "user", "content": user_message})

    # Only include ``tools`` when tools are offered; the argument is left out
    # entirely for image turns rather than being sent as an explicit null.
    request = {"model": model, "messages": messages, "max_tokens": 2048}
    if not image_data:
        request["tools"] = IMAGE_GEN_TOOLS

    try:
        resp = await self.llm.chat.completions.create(**request)
        choice = resp.choices[0]
        reply = choice.message.content or ""
        tool_calls = choice.message.tool_calls or []

        for tc in tool_calls:
            if tc.function.name == "generate_image":
                args = json.loads(tc.function.arguments)
                await self._generate_and_send_image(room.room_id, args["prompt"])

        if reply:
            await self._send_text(room.room_id, reply)
        elif not tool_calls:
            # No text and no tool call: send the apology instead of an
            # empty Matrix message.
            logger.warning("LLM returned an empty reply in %s", room.room_id)
            await self._send_text(room.room_id, "Sorry, I couldn't generate a response.")

        # Extract and store new memories (after reply sent, with timeout)
        if sender and reply:
            existing_facts = [m["fact"] for m in memories]
            try:
                await asyncio.wait_for(
                    self._extract_and_store_memories(
                        user_message, reply, existing_facts, model, sender, room.room_id
                    ),
                    timeout=15.0,
                )
            except asyncio.TimeoutError:
                logger.warning("Memory extraction timed out for %s", sender)
            except Exception:
                logger.warning("Memory save failed", exc_info=True)

        # Auto-rename: only rooms in ``auto_rename_rooms``, at most once
        # every 300 seconds per room.
        if room.room_id in self.auto_rename_rooms:
            last_rename = self.renamed_rooms.get(room.room_id, 0)
            gap_seconds = time.time() - last_rename if last_rename else float("inf")
            if gap_seconds > 300:
                await self._auto_rename_room(room, user_message, reply)
    except Exception:
        logger.exception("LLM call failed")
        await self._send_text(room.room_id, "Sorry, I couldn't generate a response.")
async def _rewrite_query(self, user_message: str, history: list[dict], model: str) -> str:
"""Rewrite user message into a standalone search query using conversation context."""
if not history or not self.llm:
return user_message
# Build a compact history summary (last 4 messages max)
recent = history[-4:]
context_lines = []
for msg in recent:
prefix = "User" if msg["role"] == "user" else "Assistant"
context_lines.append(f"{prefix}: {msg['content'][:200]}")
context_text = "\n".join(context_lines)
try:
resp = await self.llm.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": (
"You are a search query rewriter. Given conversation history and a new user message, "
"produce a single standalone search query that resolves all pronouns and references "
"(like 'this house', 'that document', 'it') using context from the conversation. "
"Reply with ONLY the rewritten search query in the same language as the user message. "
"No explanation, no quotes. If the message is already self-contained, return it as-is."
)},
{"role": "user", "content": f"Conversation:\n{context_text}\n\nNew message: {user_message}"},
],
max_tokens=100,
)
rewritten = resp.choices[0].message.content.strip().strip('"\'')
if rewritten and len(rewritten) < 500:
logger.info("Query rewritten: '%s' -> '%s'", user_message[:50], rewritten[:50])
return rewritten
except Exception:
logger.debug("Query rewrite failed, using original", exc_info=True)
return user_message
async def _auto_rename_room(self, room, user_message: str, ai_reply: str):
    """Ask the LLM for a short title of the exchange and set it as room name.

    Titles shorter than 3 or longer than 80 characters are discarded. Any
    failure is logged at debug level and otherwise ignored.
    """
    title_request = (
        "Generate a concise, 3-5 word title with an emoji as prefix "
        "that summarizes the conversation above. "
        "Use the same language as the conversation. "
        "Do not use quotation marks or formatting. "
        "Respond with ONLY the title, nothing else."
    )
    try:
        resp = await self.llm.chat.completions.create(
            model=self.room_models.get(room.room_id, DEFAULT_MODEL),
            messages=[
                {"role": "user", "content": user_message},
                {"role": "assistant", "content": ai_reply[:300]},
                {"role": "user", "content": title_request},
            ],
            max_tokens=30,
        )
        title = resp.choices[0].message.content.strip().strip('"\'')
        if not 3 <= len(title) <= 80:
            return
        await self.client.room_put_state(
            room.room_id, "m.room.name", {"name": title}, state_key="",
        )
        # Record when the rename happened.
        self.renamed_rooms[room.room_id] = time.time()
        logger.info("Auto-renamed room %s to: %s", room.room_id, title)
    except Exception:
        logger.debug("Auto-rename failed", exc_info=True)
@staticmethod
def _md_to_html(text: str) -> str:
"""Minimal markdown to HTML for Matrix formatted_body."""
import html as html_mod
safe = html_mod.escape(text)
# Code blocks (```...```)
safe = re.sub(r"```(\w*)\n(.*?)```", r"<pre><code>\2</code></pre>", safe, flags=re.DOTALL)
# Inline code
safe = re.sub(r"`([^`]+)`", r"<code>\1</code>", safe)
# Bold
safe = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", safe)
# Italic
safe = re.sub(r"\*(.+?)\*", r"<em>\1</em>", safe)
# Markdown links [text](url) — must unescape the URL parts first
def _link_repl(m):
import html as _h
label = m.group(1)
url = _h.unescape(m.group(2))
return f'<a href="{url}">{label}</a>'
safe = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _link_repl, safe)
# Bare URLs (not already in an <a> tag)
safe = re.sub(r'(?<!href=")(?<!">)(https?://[^\s<]+)', r'<a href="\1">\1</a>', safe)
# Line breaks
safe = safe.replace("\n", "<br/>")
return safe
async def _generate_and_send_image(self, room_id: str, prompt: str):
    """Create one 1024x1024 image for ``prompt`` and post it to ``room_id``.

    On any failure the error is logged and an apology is sent instead.
    """
    try:
        result = await self.llm.images.generate(
            model="dall-e-3",
            prompt=prompt,
            n=1,
            size="1024x1024",
            response_format="b64_json",
        )
        # The image arrives base64-encoded in the response payload.
        png_bytes = base64.b64decode(result.data[0].b64_json)
        await self._send_image(room_id, png_bytes, "image/png", "generated.png")
    except Exception:
        logger.exception("Image generation failed")
        await self._send_text(room_id, "Sorry, I couldn't generate that image.")
async def _send_image(self, room_id: str, image_bytes: bytes, mime_type: str, filename: str):
    """Upload ``image_bytes`` with encryption enabled and send an m.image event.

    When the upload returns decryption keys, the event carries them in a
    ``file`` object; otherwise it uses a plain ``url``.
    """
    from nio import UploadResponse

    size = len(image_bytes)
    upload_resp, maybe_keys = await self.client.upload(
        data_provider=io.BytesIO(image_bytes),
        content_type=mime_type,
        filename=filename,
        filesize=size,
        encrypt=True,
    )
    if not isinstance(upload_resp, UploadResponse):
        logger.error("Image upload failed: %s", upload_resp)
        await self._send_text(room_id, "Sorry, I couldn't upload the generated image.")
        return

    content = {
        "msgtype": "m.image",
        "body": filename,
        "info": {"mimetype": mime_type, "size": size},
    }
    if maybe_keys:
        # Encrypted attachment: the URL travels together with the key material.
        encrypted_file = {"url": upload_resp.content_uri}
        for field in ("key", "iv", "hashes", "v"):
            encrypted_file[field] = maybe_keys[field]
        content["file"] = encrypted_file
    else:
        content["url"] = upload_resp.content_uri

    await self.client.room_send(room_id, message_type="m.room.message", content=content)
async def _send_text(self, room_id: str, text: str):
    """Send ``text`` to ``room_id`` as m.text with an HTML-formatted body."""
    payload = {
        "msgtype": "m.text",
        "body": text,
        "format": "org.matrix.custom.html",
        "formatted_body": self._md_to_html(text),
    }
    await self.client.room_send(room_id, message_type="m.room.message", content=payload)
async def _get_call_encryption_key(self, room_id: str, sender: str, caller_device_id: str = "") -> bytes | None:
    """Look for a call E2EE key in the most recent room timeline events.

    Fetches the last 50 timeline events over HTTP and returns the first
    non-empty key found in an ``io.element.call.encryption_keys`` event that
    was not sent by the bot. Returns ``None`` when no key is found or the
    request fails.

    NOTE(review): ``sender`` is only used for logging and ``caller_device_id``
    is unused, so the returned key is not filtered by participant.
    """
    try:
        import httpx

        endpoint = f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/messages"
        auth_header = {"Authorization": f"Bearer {self.client.access_token}"}
        # dir=b walks the timeline backwards from the most recent event.
        query = {"dir": "b", "limit": "50"}
        async with httpx.AsyncClient(timeout=10.0) as http:
            resp = await http.get(endpoint, headers=auth_header, params=query)
            resp.raise_for_status()
            data = resp.json()

        events = data.get("chunk", [])
        logger.info("Timeline scan: %d events", len(events))

        for evt in events:
            if evt.get("type", "") != ENCRYPTION_KEYS_TYPE:
                continue
            evt_sender = evt.get("sender", "")
            if evt_sender == BOT_USER:
                continue  # skip our own keys
            content = evt.get("content", {})
            logger.info("Found encryption_keys timeline event from %s: %s",
                        evt_sender, list(content.keys()))
            for entry in content.get("keys", []):
                key_b64 = entry.get("key", "")
                if not key_b64:
                    continue
                # Restore the base64 padding before decoding.
                key_b64 += "=" * (-len(key_b64) % 4)
                key = base64.urlsafe_b64decode(key_b64)
                logger.info("Got E2EE key from timeline (%s, %d bytes)",
                            evt_sender, len(key))
                return key

        # Nothing usable found: log the event types seen for debugging.
        logger.info("Timeline event types: %s", [e.get("type", "") for e in events])
    except Exception as e:
        logger.warning("Timeline scan for encryption keys failed: %s", e)

    logger.info("No E2EE encryption key found in timeline for %s in %s", sender, room_id)
    return None
async def _publish_encryption_key(self, room_id: str, key: bytes):
    """Send the bot's call E2EE key to ``room_id`` as a timeline event.

    The key is encoded as unpadded URL-safe base64 and published at index 0
    in an ``io.element.call.encryption_keys`` event. Send failures are
    logged and not raised.
    """
    encoded_key = base64.urlsafe_b64encode(key).decode().rstrip("=")
    payload = {
        "call_id": "",
        "device_id": BOT_DEVICE_ID,
        "keys": [{"index": 0, "key": encoded_key}],
    }
    try:
        await self.client.room_send(room_id, message_type=ENCRYPTION_KEYS_TYPE, content=payload)
    except Exception:
        logger.exception("Failed to publish E2EE key in %s", room_id)
    else:
        logger.info("Published E2EE key as timeline event in %s", room_id)
async def _route_verification(self, room, event: UnknownEvent):
    """Dispatch an in-room ``m.key.verification.*`` event by its type."""
    source = event.source or {}
    verify_type = event.type
    logger.info("Verification event: %s from %s", verify_type, event.sender)

    step_handlers = {
        "m.key.verification.request": self._handle_verification_request,
        "m.key.verification.start": self._handle_verification_start,
        "m.key.verification.key": self._handle_verification_key,
        "m.key.verification.mac": self._handle_verification_mac,
    }
    handler = step_handlers.get(verify_type)
    if handler is not None:
        await handler(room, source)
        return

    if verify_type == "m.key.verification.cancel":
        # Drop the stored state for the cancelled transaction.
        txn = source.get("content", {}).get("m.relates_to", {}).get("event_id", "")
        self._verifications.pop(txn, None)
        logger.info("Verification cancelled: %s", txn)
    # "m.key.verification.done" and any other type need no action.
async def on_room_unknown(self, room, event: RoomMessageUnknown):
    """Pick verification events out of otherwise unknown room messages.

    The verification type is taken from the event type if it starts with
    ``m.key.verification.``, otherwise from the content's ``msgtype``.
    Events that are not verification events, or that the bot sent itself,
    are ignored.
    """
    source = event.source or {}
    content = source.get("content", {})
    event_type = source.get("type", "")
    msgtype = content.get("msgtype", "")
    logger.info("RoomMessageUnknown: type=%s msgtype=%s sender=%s", event_type, msgtype, event.sender)

    prefix = "m.key.verification."
    if event_type.startswith(prefix):
        verify_type = event_type
    elif msgtype.startswith(prefix):
        verify_type = msgtype
    else:
        return
    if event.sender == BOT_USER:
        return

    logger.info("Verification event: %s from %s", verify_type, event.sender)

    step_handlers = {
        "m.key.verification.request": self._handle_verification_request,
        "m.key.verification.start": self._handle_verification_start,
        "m.key.verification.key": self._handle_verification_key,
        "m.key.verification.mac": self._handle_verification_mac,
    }
    handler = step_handlers.get(verify_type)
    if handler is not None:
        await handler(room, source)
    elif verify_type == "m.key.verification.cancel":
        # Drop the stored state for the cancelled transaction.
        txn = content.get("m.relates_to", {}).get("event_id", "")
        self._verifications.pop(txn, None)
        logger.info("Verification cancelled: %s", txn)
async def _handle_verification_request(self, room, source):
    """Answer an in-room verification request with ``m.key.verification.ready``.

    The request's event ID becomes the transaction ID under which the
    verification state is stored. Requests whose ``to`` field names a user
    other than the bot are ignored, so the bot does not answer verifications
    between other room members.
    """
    content = source["content"]
    txn_id = source["event_id"]
    sender = source["sender"]

    # ``to`` names the intended verification partner; requests without the
    # field are still accepted as before.
    target = content.get("to")
    if target and target not in (BOT_USER, self.client.user_id):
        logger.debug("Ignoring verification request %s addressed to %s", txn_id, target)
        return

    self._verifications[txn_id] = {"sender": sender, "room_id": room.room_id}
    logger.info("Verification request from %s, txn=%s", sender, txn_id)

    # Send m.key.verification.ready
    await self.client.room_send(
        room.room_id,
        message_type="m.key.verification.ready",
        content={
            "m.relates_to": {"rel_type": "m.reference", "event_id": txn_id},
            "from_device": self.client.device_id,
            "methods": ["m.sas.v1"],
        },
    )
    logger.info("Sent verification ready for %s", txn_id)
async def _handle_verification_start(self, room, source):
    """React to ``m.key.verification.start`` by sending the bot's SAS public key.

    A fresh SAS object is created and stored in the transaction state.
    Starts for unknown transactions are logged and ignored.
    """
    content = source["content"]
    txn_id = content.get("m.relates_to", {}).get("event_id", "")
    state = self._verifications.get(txn_id)
    if not state:
        logger.warning("Unknown verification start: %s", txn_id)
        return

    sas_obj = olm_sas.Sas()
    state["sas"] = sas_obj
    state["commitment"] = content.get("commitment", "")

    # No m.key.verification.accept is sent here; the bot answers the start
    # directly with its key.
    key_payload = {
        "m.relates_to": {"rel_type": "m.reference", "event_id": txn_id},
        "key": sas_obj.pubkey,
    }
    await self.client.room_send(
        room.room_id, message_type="m.key.verification.key", content=key_payload
    )
    state["key_sent"] = True
    logger.info("Sent SAS key for %s", txn_id)
async def _handle_verification_key(self, room, source):
    """Handle the peer's ``m.key.verification.key`` and reply with the bot's MAC.

    Stores the peer's public key in the SAS object, then sends an
    ``m.key.verification.mac`` event covering the bot's ed25519 device key
    and the list of MACed key IDs. The short authentication string is not
    compared; the bot proceeds straight to the MAC step.
    """
    content = source["content"]
    # The transaction ID is the event ID of the original request.
    txn_id = content.get("m.relates_to", {}).get("event_id", "")
    v = self._verifications.get(txn_id)
    # A SAS object only exists once a start event was handled for this txn.
    if not v or "sas" not in v:
        logger.warning("Unknown verification key: %s", txn_id)
        return
    sas_obj = v["sas"]
    their_key = content["key"]
    sas_obj.set_their_pubkey(their_key)
    v["their_key"] = their_key
    # Auto-confirm SAS (bot trusts the user)
    # Generate MAC for our device key (only the device key is MACed below)
    our_user = BOT_USER
    our_device = self.client.device_id
    their_user = v["sender"]
    # Key IDs to MAC
    key_id = f"ed25519:{our_device}"
    device_key = self.client.olm.account.identity_keys["ed25519"]
    # MAC info string: prefix, our user/device, their user/device, txn ID.
    # NOTE(review): the peer device is read from this event's `from_device`
    # and falls back to "" when absent — confirm that key events carry it,
    # otherwise the info string lacks the peer's device ID.
    base_info = (
        f"MATRIX_KEY_VERIFICATION_MAC"
        f"{our_user}{our_device}"
        f"{their_user}{content.get('from_device', '')}"
        f"{txn_id}"
    )
    mac_dict = {}
    keys_list = []
    # MAC our ed25519 device key
    mac_dict[key_id] = sas_obj.calculate_mac(device_key, base_info + key_id)
    keys_list.append(key_id)
    # MAC the key list (sorted, comma-separated key IDs)
    keys_str = ",".join(sorted(keys_list))
    keys_mac = sas_obj.calculate_mac(keys_str, base_info + "KEY_IDS")
    await self.client.room_send(
        room.room_id,
        message_type="m.key.verification.mac",
        content={
            "m.relates_to": {"rel_type": "m.reference", "event_id": txn_id},
            "mac": mac_dict,
            "keys": keys_mac,
        },
    )
    logger.info("Sent SAS MAC for %s", txn_id)
async def _handle_verification_mac(self, room, source):
    """Finish a known verification by sending ``m.key.verification.done``.

    The peer's MAC is not checked here. The transaction state is removed
    after the done event has been sent.
    """
    txn_id = source["content"].get("m.relates_to", {}).get("event_id", "")
    session = self._verifications.get(txn_id)
    if not session:
        return

    done_payload = {"m.relates_to": {"rel_type": "m.reference", "event_id": txn_id}}
    await self.client.room_send(
        room.room_id, message_type="m.key.verification.done", content=done_payload
    )
    logger.info("Verification complete for %s with %s", txn_id, session["sender"])
    self._verifications.pop(txn_id, None)
async def on_megolm(self, room, event: MegolmEvent):
    """Log an undecryptable event and ask for its room key.

    A failing key request is logged at debug level and not raised.
    """
    logger.warning(
        "Undecryptable event %s in %s from %s — requesting keys",
        event.event_id,
        room.room_id,
        event.sender,
    )
    try:
        await self.client.request_room_key(event)
    except Exception:
        logger.debug("Key request failed", exc_info=True)
async def on_key_verification(self, event):
    """Drive a to-device SAS verification forward without user interaction.

    Start events are accepted and answered with the bot's key, key events
    confirm the short auth string, and MAC events are answered with the
    bot's MAC. Events without a tracked verification are ignored.
    """
    handled_types = (KeyVerificationStart, KeyVerificationKey, KeyVerificationMac)
    if not isinstance(event, handled_types):
        return

    txn_id = event.transaction_id
    sas = self.client.key_verifications.get(txn_id)
    if not sas:
        return

    if isinstance(event, KeyVerificationStart):
        await self.client.accept_key_verification(txn_id)
        await self.client.to_device(sas.share_key())
    elif isinstance(event, KeyVerificationKey):
        await self.client.confirm_short_auth_string(txn_id)
    else:
        # KeyVerificationMac: only send the MAC when one could be produced.
        mac = sas.get_mac()
        if not isinstance(mac, ToDeviceError):
            await self.client.to_device(mac)
async def cleanup(self):
    """Close the Matrix client and, if present, the LiveKit API client.

    The LiveKit client is closed even when closing the Matrix client raises;
    that exception still propagates to the caller afterwards.
    """
    try:
        await self.client.close()
    finally:
        if self.lkapi:
            await self.lkapi.aclose()
async def main():
    """Ensure the store directory exists, run the bot, and always clean up."""
    os.makedirs(STORE_PATH, exist_ok=True)
    instance = Bot()
    try:
        await instance.start()
    finally:
        # Runs whether start() returns normally or raises.
        await instance.cleanup()
# Script entry point: set up INFO-level logging, then run the bot until it exits.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())