source_url is now a top-level field on DocumentChunk, not nested in metadata. Fall back to metadata for backwards compatibility. Refs: WF-90 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
770 lines
31 KiB
Python
770 lines
31 KiB
Python
import os
|
|
import json
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
import time
|
|
import uuid
|
|
|
|
import httpx
|
|
from openai import AsyncOpenAI
|
|
from olm import sas as olm_sas
|
|
|
|
from nio import (
|
|
AsyncClient,
|
|
AsyncClientConfig,
|
|
LoginResponse,
|
|
InviteMemberEvent,
|
|
MegolmEvent,
|
|
RoomMessageText,
|
|
RoomMessageUnknown,
|
|
SyncResponse,
|
|
UnknownEvent,
|
|
KeyVerificationStart,
|
|
KeyVerificationCancel,
|
|
KeyVerificationKey,
|
|
KeyVerificationMac,
|
|
ToDeviceError,
|
|
)
|
|
from livekit import api
|
|
|
|
# Fixed identifiers: the device id the bot advertises in its call-member
# events, the MSC3401 call-member event type, and two custom room-state types
# used to persist per-room settings.
BOT_DEVICE_ID = "AIBOT"
CALL_MEMBER_TYPE = "org.matrix.msc3401.call.member"
MODEL_STATE_TYPE = "ai.agiliton.model"
RENAME_STATE_TYPE = "ai.agiliton.auto_rename"

logger = logging.getLogger("matrix-ai-bot")

# Required settings: indexing os.environ raises KeyError at import if unset.
HOMESERVER = os.environ["MATRIX_HOMESERVER"]
BOT_USER = os.environ["MATRIX_BOT_USER"]
BOT_PASS = os.environ["MATRIX_BOT_PASSWORD"]
LK_URL = os.environ["LIVEKIT_URL"]
LK_KEY = os.environ["LIVEKIT_API_KEY"]
LK_SECRET = os.environ["LIVEKIT_API_SECRET"]

# Optional settings (os.getenv(key, default) == os.environ.get(key, default)).
AGENT_NAME = os.getenv("AGENT_NAME", "matrix-ai")
STORE_PATH = os.getenv("CRYPTO_STORE_PATH", "/data/crypto_store")
CREDS_FILE = os.path.join(STORE_PATH, "credentials.json")

# LLM gateway: an empty base URL means no LLM client is created.
LITELLM_URL = os.getenv("LITELLM_BASE_URL", "")
LITELLM_KEY = os.getenv("LITELLM_API_KEY", "not-needed")
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "claude-sonnet")

# WildFiles document search: needs both a base URL and an organization.
WILDFILES_BASE_URL = os.getenv("WILDFILES_BASE_URL", "")
WILDFILES_ORG = os.getenv("WILDFILES_ORG", "")
|
|
|
|
# System prompt placed first in every LLM request made for a chat reply.
# NOTE(review): it forbids the model from mentioning "!ai" commands even though
# HELP_TEXT below advertises them — presumably so the model never tells users
# to type commands; confirm this split is intended.
SYSTEM_PROMPT = """You are a helpful AI assistant in a Matrix chat room.
Keep answers concise but thorough. Use markdown formatting when helpful.

IMPORTANT RULES — FOLLOW THESE STRICTLY:
- When document context is provided below, use it to answer. Always include any links.
- NEVER tell the user to run commands or type anything special. No commands exist.
- NEVER mention "!ai", "!ai search", "!ai read", or any slash/bang commands.
- NEVER say you cannot access files, documents, or links.
- NEVER ask the user where documents are stored, how they were uploaded, or under what filename.
- NEVER suggest contacting an administrator, using a web interface, or checking another system.
- NEVER ask follow-up questions about document storage or file locations.
- If no relevant documents were found, simply say you don't have information on that topic and ask if you can help with something else. Do NOT speculate about why or suggest the user look elsewhere."""

# Markdown help sent for a bare "!ai", for "!ai help", and for unknown
# commands when no LLM is configured.
HELP_TEXT = """**AI Bot Commands**
- `!ai help` — Show this help
- `!ai models` — List available models
- `!ai set-model <model>` — Set model for this room
- `!ai search <query>` — Search documents (WildFiles)
- `!ai auto-rename on|off` — Auto-rename room based on conversation topic
- **@mention the bot** or start with `!ai` for a regular AI response"""
|
|
|
|
|
|
class DocumentRAG:
    """Search WildFiles for relevant documents and format hits as LLM context.

    Search is disabled (always returns no results) unless both a base URL and
    an organization are configured.
    """

    def __init__(self, base_url: str, org: str):
        self.base_url = base_url.rstrip("/")
        self.org = org
        self.enabled = bool(base_url and org)

    async def search(self, query: str, top_k: int = 3) -> list[dict]:
        """Return up to ``top_k`` WildFiles hits for ``query``.

        Never raises: network, HTTP-status and JSON-decoding failures are
        logged and yield ``[]``, so callers treat "no results" and "search
        unavailable" the same way. Non-dict entries in the response are
        dropped so ``format_context`` can rely on ``dict`` access.
        """
        if not self.enabled or not query.strip():
            return []
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                resp = await client.post(
                    f"{self.base_url}/api/v1/rag/search",
                    json={"query": query, "organization": self.org, "limit": top_k},
                )
                resp.raise_for_status()
                payload = resp.json()
        except Exception:
            # WARNING rather than DEBUG: an unreachable document service would
            # otherwise silently degrade every answer.
            logger.warning("WildFiles search failed", exc_info=True)
            return []
        results = payload.get("results") if isinstance(payload, dict) else None
        if not isinstance(results, list):
            return []
        return [r for r in results if isinstance(r, dict)]

    def format_context(self, results: list[dict]) -> str:
        """Render search hits as a plain-text block for the LLM prompt.

        Returns ``""`` for no results. The link is read from the top-level
        ``source_url`` and falls back to ``metadata.source_url`` for older
        responses; a missing or ``null`` ``metadata`` is treated as empty.
        """
        if not results:
            return ""
        parts = ["The following documents were found in our document archive:"]
        for r in results:
            metadata = r.get("metadata")
            if not isinstance(metadata, dict):
                metadata = {}
            # `or` chain (not a .get default) so null/empty titles fall back too.
            title = r.get("title") or r.get("filename") or "Untitled"
            fields = (
                ("Filename", metadata.get("original_filename", "")),
                ("Category", r.get("category", "")),
                ("Date", r.get("detected_date", "")),
                ("Link", r.get("source_url") or metadata.get("source_url", "")),
            )
            parts.append(f"- Title: {title}")
            for label, value in fields:
                if value:
                    parts.append(f" {label}: {value}")
        parts.append("\nUse this information to answer the user. Always include document links when referencing documents.")
        return "\n".join(parts)
|
|
|
|
|
|
class Bot:
    """Matrix bot that answers chat messages with an LLM and joins MSC3401 calls.

    As implemented below it: logs in (persisting credentials next to the crypto
    store), auto-joins invites, auto-trusts devices, answers ``!ai`` commands
    and LLM prompts (with WildFiles context), dispatches a LiveKit agent when a
    call starts, and handles SAS key verification (in-room and to-device).
    """

    # State key of the bot's own call-member event.
    _CALL_STATE_KEY = f"_{BOT_USER}_{BOT_DEVICE_ID}_m.call"

    def __init__(self):
        config = AsyncClientConfig(
            max_limit_exceeded=0,
            max_timeouts=0,
            store_sync_tokens=True,
            encryption_enabled=True,
        )
        self.client = AsyncClient(
            HOMESERVER,
            BOT_USER,
            store_path=STORE_PATH,
            config=config,
        )
        self.lkapi = None
        self.dispatched_rooms = set()  # rooms whose current call already has an agent dispatched
        self.active_calls = set()  # rooms where we've sent call member event
        # room_id -> call-member state keys of other participants currently in the call
        self._call_members: dict[str, set[str]] = {}
        self.rag = DocumentRAG(WILDFILES_BASE_URL, WILDFILES_ORG)
        self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
        self.room_models: dict[str, str] = {}  # room_id -> model name
        self.auto_rename_rooms: set[str] = set()  # rooms with auto-rename enabled
        self.renamed_rooms: set[str] = set()  # rooms already renamed this session
        self._loaded_rooms: set[str] = set()  # rooms where we've loaded state
        self._sync_token_received = False
        self._verifications: dict[str, dict] = {}  # txn_id -> verification state

    # ------------------------------------------------------------------ lifecycle

    async def start(self):
        """Log in (or restore the saved session), register callbacks, then sync forever."""
        if os.path.exists(CREDS_FILE):
            with open(CREDS_FILE) as f:
                creds = json.load(f)
            self.client.restore_login(
                user_id=creds["user_id"],
                device_id=creds["device_id"],
                access_token=creds["access_token"],
            )
            self.client.load_store()
            logger.info("Restored session as %s (device %s)", creds["user_id"], creds["device_id"])
        else:
            resp = await self.client.login(BOT_PASS, device_name="ai-voice-bot")
            if not isinstance(resp, LoginResponse):
                logger.error("Login failed: %s", resp)
                return
            self._save_credentials(resp)
            logger.info("Logged in as %s (device %s) — credentials saved", resp.user_id, resp.device_id)

        if self.client.should_upload_keys:
            await self.client.keys_upload()

        self.lkapi = api.LiveKitAPI(LK_URL, LK_KEY, LK_SECRET)
        self.client.add_event_callback(self.on_invite, InviteMemberEvent)
        self.client.add_event_callback(self.on_megolm, MegolmEvent)
        self.client.add_event_callback(self.on_unknown, UnknownEvent)
        self.client.add_event_callback(self.on_text_message, RoomMessageText)
        self.client.add_event_callback(self.on_room_unknown, RoomMessageUnknown)
        self.client.add_response_callback(self.on_sync, SyncResponse)
        for to_device_type in (KeyVerificationStart, KeyVerificationKey,
                               KeyVerificationMac, KeyVerificationCancel):
            self.client.add_to_device_callback(self.on_key_verification, to_device_type)

        await self.client.sync_forever(timeout=30000, full_state=True)

    @staticmethod
    def _save_credentials(resp: LoginResponse):
        """Persist the login so restarts reuse the same device and crypto store.

        The file contains an access token, so it is created readable and
        writable by the owner only (mode 0o600; applies when the file is new).
        """
        fd = os.open(CREDS_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            json.dump({
                "user_id": resp.user_id,
                "device_id": resp.device_id,
                "access_token": resp.access_token,
            }, f)

    async def on_invite(self, room, event: InviteMemberEvent):
        """Join any room the bot itself is invited to."""
        if event.state_key != BOT_USER:
            return
        logger.info("Invited to %s, joining room", room.room_id)
        await self.client.join(room.room_id)

    async def on_sync(self, response: SyncResponse):
        """After each sync, mark every known, not-yet-verified device as verified."""
        if not self._sync_token_received:
            self._sync_token_received = True
            logger.info("Initial sync complete, text handler active")
        # Deliberate blanket trust so encrypted rooms work without manual verification.
        for user_id in list(self.client.device_store.users):
            for device in self.client.device_store.active_user_devices(user_id):
                if not device.verified:
                    self.client.verify_device(device)
                    logger.info("Auto-trusted device %s of %s", device.device_id, user_id)

    # ------------------------------------------------------------------ calls

    async def on_unknown(self, room, event: UnknownEvent):
        """Handle in-room verification events and call-member state events."""
        if event.type.startswith("m.key.verification."):
            if event.sender != BOT_USER:
                await self._route_verification(room, event)
            return
        if event.type != CALL_MEMBER_TYPE or event.sender == BOT_USER:
            return  # not a call event, or our own

        room_id = room.room_id
        content = event.source.get("content") or {}
        # Track participants by state key, falling back to the sender.
        member_key = event.source.get("state_key") or event.sender
        members = self._call_members.setdefault(room_id, set())

        if content:
            # Non-empty content means this participant started/is in a call.
            members.add(member_key)
            if room_id not in self.active_calls:
                await self._join_call(room_id, content, event.sender)
            return

        # Empty content means this participant left the call.
        members.discard(member_key)
        if members:
            return  # others are still in the call; stay
        self._call_members.pop(room_id, None)
        # Forget the dispatch so the next call in this room dispatches an agent again.
        self.dispatched_rooms.discard(room_id)
        if room_id in self.active_calls:
            await self._leave_call(room_id)

    async def _join_call(self, room_id: str, content: dict, caller: str):
        """Dispatch the LiveKit agent (once per call) and publish our call-member event."""
        logger.info("Call detected in %s from %s, joining...", room_id, caller)
        self.active_calls.add(room_id)

        # Reuse the caller's preferred foci; `or` also covers an empty list.
        foci = content.get("foci_preferred") or [{
            "type": "livekit",
            "livekit_service_url": f"{HOMESERVER}/livekit-jwt-service",
            "livekit_alias": room_id,
        }]

        # LiveKit room = alias of the first livekit focus, else the Matrix room id.
        lk_room_name = room_id
        for focus in foci:
            if focus.get("type") == "livekit" and focus.get("livekit_alias"):
                lk_room_name = focus["livekit_alias"]
                break
        logger.info("LiveKit room name: %s (from foci_preferred)", lk_room_name)

        if room_id not in self.dispatched_rooms:
            try:
                await self.lkapi.agent_dispatch.create_dispatch(
                    api.CreateAgentDispatchRequest(
                        agent_name=AGENT_NAME,
                        room=lk_room_name,
                    )
                )
                self.dispatched_rooms.add(room_id)
                logger.info("Agent dispatched to LiveKit room %s", lk_room_name)
            except Exception:
                logger.exception("Dispatch failed for %s", lk_room_name)

        call_content = {
            "application": "m.call",
            "call_id": "",
            "scope": "m.room",
            "device_id": BOT_DEVICE_ID,
            "expires": 7200000,
            "focus_active": {
                "type": "livekit",
                "focus_selection": "oldest_membership",
            },
            "foci_preferred": foci,
            "m.call.intent": "audio",
        }
        try:
            resp = await self.client.room_put_state(
                room_id, CALL_MEMBER_TYPE, call_content, state_key=self._CALL_STATE_KEY,
            )
            logger.info("Sent call member event in %s: %s", room_id, resp)
        except Exception:
            logger.exception("Failed to send call member event in %s", room_id)

    async def _leave_call(self, room_id: str):
        """Clear our call-member state event (empty content) to leave the call."""
        self.active_calls.discard(room_id)
        try:
            await self.client.room_put_state(
                room_id, CALL_MEMBER_TYPE, {}, state_key=self._CALL_STATE_KEY,
            )
            logger.info("Left call in %s", room_id)
        except Exception:
            logger.exception("Failed to leave call in %s", room_id)

    # ------------------------------------------------------------------ chat

    async def _load_room_settings(self, room_id: str):
        """Load persisted model and auto-rename settings from room state (once per room)."""
        if room_id in self._loaded_rooms:
            return
        self._loaded_rooms.add(room_id)
        for state_type, target in ((MODEL_STATE_TYPE, "model"), (RENAME_STATE_TYPE, "rename")):
            try:
                resp = await self.client.room_get_state_event(room_id, state_type, "")
            except Exception:
                logger.debug("Could not load %s for %s", state_type, room_id, exc_info=True)
                continue
            # Responses without dict content (presumably errors such as "state
            # event not set yet") are skipped.
            content = getattr(resp, "content", None)
            if not isinstance(content, dict):
                continue
            if target == "model" and "model" in content:
                self.room_models[room_id] = content["model"]
            elif target == "rename" and content.get("enabled"):
                self.auto_rename_rooms.add(room_id)

    def _is_mentioned(self, body: str, formatted_body: str = "") -> bool:
        """Return True if the message addresses the bot.

        Matches the full user id in the plain or HTML body (pills link to it),
        or the bot's localpart as a whole word, optionally prefixed with "@".
        A plain substring test would fire on any word containing the localpart.
        """
        if BOT_USER in body or BOT_USER in formatted_body:
            return True
        localpart = (self.client.user_id or BOT_USER).split(":")[0].lstrip("@")
        if not localpart:
            return False
        pattern = rf"(?<![\w.@-])@?{re.escape(localpart)}(?![\w-])"
        return re.search(pattern, body, re.IGNORECASE) is not None

    async def on_text_message(self, room, event: RoomMessageText):
        """Handle text messages: ``!ai`` commands, and LLM replies in DMs or on mention."""
        if event.sender == BOT_USER:
            return
        if not self._sync_token_received:
            return  # ignore messages from initial sync / backfill
        # Ignore old messages (>30s) to avoid replaying history
        if time.time() - event.server_timestamp / 1000 > 30:
            return

        await self._load_room_settings(room.room_id)
        body = event.body.strip()

        if body.startswith("!ai "):
            await self._handle_command(room, body[4:].strip(), event_id=event.event_id)
            return
        if body == "!ai":
            await self._send_text(room.room_id, HELP_TEXT)
            return

        # In DMs (2 members), respond to all messages; in groups, require a mention
        is_dm = room.member_count == 2
        formatted = getattr(event, "formatted_body", None) or ""
        if not is_dm and not self._is_mentioned(body, formatted):
            return

        if not self.llm:
            await self._send_text(room.room_id, "LLM not configured (LITELLM_BASE_URL not set).")
            return

        await self._respond_with_typing(room, body, exclude_event_id=event.event_id)

    async def _handle_command(self, room, cmd: str, event_id=None):
        """Execute ``!ai <cmd>``; unknown commands become LLM prompts when an LLM exists.

        ``event_id`` optionally names the triggering event so it is left out of
        the LLM history.
        """
        # Split "<name> <arg>" so a bare "set-model"/"auto-rename"/"search"
        # reaches its usage message instead of falling through to the LLM.
        name, _, arg = cmd.partition(" ")
        arg = arg.strip()
        room_id = room.room_id

        if cmd == "help":
            await self._send_text(room_id, HELP_TEXT)

        elif cmd == "models":
            if not self.llm:
                await self._send_text(room_id, "LLM not configured.")
                return
            try:
                models = await self.llm.models.list()
                names = sorted(m.id for m in models.data)
                current = self.room_models.get(room_id, DEFAULT_MODEL)
                text = "**Available models:**\n"
                text += "\n".join(f"- `{n}` {'← current' if n == current else ''}" for n in names)
                await self._send_text(room_id, text)
            except Exception:
                logger.exception("Failed to list models")
                await self._send_text(room_id, "Failed to fetch model list.")

        elif name == "set-model":
            if not arg:
                await self._send_text(room_id, "Usage: `!ai set-model <model-name>`")
                return
            self.room_models[room_id] = arg
            # Persist in room state for cross-restart persistence
            try:
                await self.client.room_put_state(
                    room_id, MODEL_STATE_TYPE, {"model": arg}, state_key="",
                )
            except Exception:
                logger.debug("Could not persist model to room state", exc_info=True)
            await self._send_text(room_id, f"Model set to `{arg}` for this room.")

        elif name == "auto-rename":
            setting = arg.lower()
            if setting not in ("on", "off"):
                await self._send_text(room_id, "Usage: `!ai auto-rename on|off`")
                return
            enabled = setting == "on"
            if enabled:
                self.auto_rename_rooms.add(room_id)
            else:
                self.auto_rename_rooms.discard(room_id)
            try:
                await self.client.room_put_state(
                    room_id, RENAME_STATE_TYPE,
                    {"enabled": enabled}, state_key="",
                )
            except Exception:
                logger.debug("Could not persist auto-rename to room state", exc_info=True)
            status = "enabled" if enabled else "disabled"
            await self._send_text(room_id, f"Auto-rename **{status}** for this room.")

        elif name == "search":
            if not arg:
                await self._send_text(room_id, "Usage: `!ai search <query>`")
                return
            results = await self.rag.search(arg, top_k=5)
            if not results:
                await self._send_text(room_id, "No documents found.")
                return
            await self._send_text(room_id, self.rag.format_context(results))

        elif self.llm:
            # Treat unknown commands as AI prompts
            await self._respond_with_typing(room, cmd, exclude_event_id=event_id)
        else:
            await self._send_text(room_id, f"Unknown command: `{cmd}`\n\n{HELP_TEXT}")

    async def _respond_with_typing(self, room, prompt: str, exclude_event_id=None):
        """Run ``_respond_with_ai`` while a typing notification is shown."""
        await self.client.room_typing(room.room_id, typing_state=True)
        try:
            await self._respond_with_ai(room, prompt, exclude_event_id=exclude_event_id)
        finally:
            await self.client.room_typing(room.room_id, typing_state=False)

    async def _fetch_history(self, room_id: str, exclude_event_id=None) -> list[dict]:
        """Return the room's last 10 message events as chat messages, oldest first.

        Events without a ``body`` are skipped, as is ``exclude_event_id``: the
        caller appends the prompt itself, so including the triggering event
        would send it twice. Failures yield ``[]``.
        """
        history = []
        try:
            resp = await self.client.room_messages(
                room_id, start=self.client.next_batch or "", limit=10
            )
            for evt in reversed(getattr(resp, "chunk", None) or []):
                if not hasattr(evt, "body"):
                    continue
                if exclude_event_id and getattr(evt, "event_id", None) == exclude_event_id:
                    continue
                role = "assistant" if evt.sender == BOT_USER else "user"
                history.append({"role": role, "content": evt.body})
        except Exception:
            logger.debug("Could not fetch room history, proceeding without context")
            return []
        return history

    async def _respond_with_ai(self, room, user_message: str, exclude_event_id=None):
        """Answer ``user_message`` with the room's model, using RAG and recent history."""
        model = self.room_models.get(room.room_id, DEFAULT_MODEL)
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        # WildFiles document context
        doc_results = await self.rag.search(user_message)
        doc_context = self.rag.format_context(doc_results)
        if doc_context:
            logger.info("RAG found %d docs for: %s", len(doc_results), user_message[:50])
            messages.append({"role": "system", "content": doc_context})
        else:
            logger.info("RAG found 0 docs for: %s", user_message[:50])

        messages.extend(await self._fetch_history(room.room_id, exclude_event_id))
        messages.append({"role": "user", "content": user_message})

        try:
            resp = await self.llm.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=2048,
            )
            reply = resp.choices[0].message.content
            if not reply:
                logger.warning("LLM returned an empty reply (model %s)", model)
                await self._send_text(room.room_id, "Sorry, I couldn't generate a response.")
                return
            await self._send_text(room.room_id, reply)
            # Auto-rename room after first AI response
            if (room.room_id in self.auto_rename_rooms
                    and room.room_id not in self.renamed_rooms):
                await self._auto_rename_room(room, user_message, reply)
        except Exception:
            logger.exception("LLM call failed")
            await self._send_text(room.room_id, "Sorry, I couldn't generate a response.")

    async def _auto_rename_room(self, room, user_message: str, ai_reply: str):
        """Ask the LLM for a short topic title and set it as the room name (best effort)."""
        try:
            resp = await self.llm.chat.completions.create(
                model=self.room_models.get(room.room_id, DEFAULT_MODEL),
                messages=[
                    {"role": "system", "content": (
                        "Generate a very short room title (3-6 words, no quotes) "
                        "that captures the topic of this conversation. "
                        "Reply with ONLY the title, nothing else."
                    )},
                    {"role": "user", "content": user_message},
                    {"role": "assistant", "content": ai_reply[:200]},
                    {"role": "user", "content": "What is a good short title for this conversation?"},
                ],
                max_tokens=30,
            )
            title = (resp.choices[0].message.content or "").strip().strip('"\'')
            if not title or len(title) > 80:
                return
            await self.client.room_put_state(
                room.room_id, "m.room.name",
                {"name": title}, state_key="",
            )
            self.renamed_rooms.add(room.room_id)
            logger.info("Auto-renamed room %s to: %s", room.room_id, title)
        except Exception:
            logger.debug("Auto-rename failed", exc_info=True)

    @staticmethod
    def _md_to_html(text: str) -> str:
        """Minimal markdown to HTML for Matrix formatted_body.

        Supports fenced code blocks, inline code, **bold**, *italic* and line
        breaks. Code is stashed behind NUL-delimited placeholders before the
        emphasis/line-break passes and restored last, so its contents are
        left literal (newlines inside ``<pre>`` stay as-is).
        """
        import html as html_mod

        stash: list[str] = []

        def _hold(fragment: str) -> str:
            stash.append(fragment)
            return f"\x00{len(stash) - 1}\x00"

        safe = html_mod.escape(text)
        # Code blocks (```lang\n...```)
        safe = re.sub(
            r"```(\w*)\n(.*?)```",
            lambda m: _hold(f"<pre><code>{m.group(2)}</code></pre>"),
            safe,
            flags=re.DOTALL,
        )
        # Inline code (single line only, so a stray backtick can't swallow line breaks)
        safe = re.sub(r"`([^`\n]+)`", lambda m: _hold(f"<code>{m.group(1)}</code>"), safe)
        safe = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", safe)
        safe = re.sub(r"\*(.+?)\*", r"<em>\1</em>", safe)
        safe = safe.replace("\n", "<br/>")
        return re.sub(r"\x00(\d+)\x00", lambda m: stash[int(m.group(1))], safe)

    async def _send_text(self, room_id: str, text: str):
        """Send ``text`` as m.text with an HTML rendering of its markdown."""
        await self.client.room_send(
            room_id,
            message_type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": text,
                "format": "org.matrix.custom.html",
                "formatted_body": self._md_to_html(text),
            },
        )

    # ------------------------------------------------------------------ in-room verification

    async def _route_verification(self, room, event: UnknownEvent):
        """Route in-room verification events delivered as UnknownEvent."""
        await self._dispatch_verification(room, event.type, event.source or {}, event.sender)

    async def on_room_unknown(self, room, event: RoomMessageUnknown):
        """Handle in-room verification events delivered as RoomMessageUnknown."""
        source = event.source or {}
        content = source.get("content", {})
        event_type = source.get("type", "")
        msgtype = content.get("msgtype", "")

        logger.info("RoomMessageUnknown: type=%s msgtype=%s sender=%s", event_type, msgtype, event.sender)

        # In-room verification events can come as m.room.message with msgtype=m.key.verification.*
        # or as direct event types m.key.verification.*
        if event_type.startswith("m.key.verification."):
            verify_type = event_type
        elif msgtype.startswith("m.key.verification."):
            verify_type = msgtype
        else:
            return
        if event.sender == BOT_USER:
            return
        await self._dispatch_verification(room, verify_type, source, event.sender)

    async def _dispatch_verification(self, room, verify_type: str, source: dict, sender: str):
        """Shared dispatcher for in-room ``m.key.verification.*`` events."""
        logger.info("Verification event: %s from %s", verify_type, sender)
        handlers = {
            "m.key.verification.request": self._handle_verification_request,
            "m.key.verification.start": self._handle_verification_start,
            "m.key.verification.key": self._handle_verification_key,
            "m.key.verification.mac": self._handle_verification_mac,
        }
        handler = handlers.get(verify_type)
        if handler is not None:
            await handler(room, source)
        elif verify_type == "m.key.verification.cancel":
            content = source.get("content", {})
            txn = content.get("m.relates_to", {}).get("event_id", "")
            self._verifications.pop(txn, None)
            logger.info("Verification cancelled: %s", txn)
        # m.key.verification.done and anything else need no action.

    async def _handle_verification_request(self, room, source):
        """Record a new verification (keyed by the request event id) and reply "ready"."""
        content = source["content"]
        txn_id = source["event_id"]
        sender = source["sender"]

        self._verifications[txn_id] = {
            "sender": sender,
            "room_id": room.room_id,
            # Needed later for the MAC info string; key events do not carry it.
            "their_device": content.get("from_device", ""),
        }
        logger.info("Verification request from %s, txn=%s", sender, txn_id)

        # Send m.key.verification.ready
        await self.client.room_send(
            room.room_id,
            message_type="m.key.verification.ready",
            content={
                "m.relates_to": {"rel_type": "m.reference", "event_id": txn_id},
                "from_device": self.client.device_id,
                "methods": ["m.sas.v1"],
            },
        )
        logger.info("Sent verification ready for %s", txn_id)

    async def _handle_verification_start(self, room, source):
        """Create our SAS object and send our public key for a known verification."""
        content = source["content"]
        txn_id = content.get("m.relates_to", {}).get("event_id", "")
        v = self._verifications.get(txn_id)
        if not v:
            logger.warning("Unknown verification start: %s", txn_id)
            return
        if content.get("from_device"):
            v["their_device"] = content["from_device"]

        sas_obj = olm_sas.Sas()
        v["sas"] = sas_obj
        v["commitment"] = content.get("commitment", "")

        # NOTE(review): no m.key.verification.accept is sent; this goes straight
        # to sending our key. The spec has the non-starting side send an accept
        # (with a commitment) first — confirm this works with the clients in use.
        await self.client.room_send(
            room.room_id,
            message_type="m.key.verification.key",
            content={
                "m.relates_to": {"rel_type": "m.reference", "event_id": txn_id},
                "key": sas_obj.pubkey,
            },
        )
        v["key_sent"] = True
        logger.info("Sent SAS key for %s", txn_id)

    async def _handle_verification_key(self, room, source):
        """Store the peer's SAS key and reply with our MAC (SAS auto-confirmed)."""
        content = source["content"]
        txn_id = content.get("m.relates_to", {}).get("event_id", "")
        v = self._verifications.get(txn_id)
        if not v or "sas" not in v:
            logger.warning("Unknown verification key: %s", txn_id)
            return

        sas_obj = v["sas"]
        their_key = content["key"]
        sas_obj.set_their_pubkey(their_key)
        v["their_key"] = their_key

        our_user = BOT_USER
        our_device = self.client.device_id
        their_user = v["sender"]
        # The in-room key event carries only the key and the relation, so the
        # peer device comes from the request/start events recorded earlier.
        their_device = v.get("their_device") or content.get("from_device", "")

        key_id = f"ed25519:{our_device}"
        device_key = self.client.olm.account.identity_keys["ed25519"]

        # MAC info string: our user/device, then theirs, then the transaction id.
        base_info = (
            "MATRIX_KEY_VERIFICATION_MAC"
            f"{our_user}{our_device}"
            f"{their_user}{their_device}"
            f"{txn_id}"
        )
        mac_dict = {key_id: sas_obj.calculate_mac(device_key, base_info + key_id)}
        keys_str = ",".join(sorted(mac_dict))
        keys_mac = sas_obj.calculate_mac(keys_str, base_info + "KEY_IDS")

        await self.client.room_send(
            room.room_id,
            message_type="m.key.verification.mac",
            content={
                "m.relates_to": {"rel_type": "m.reference", "event_id": txn_id},
                "mac": mac_dict,
                "keys": keys_mac,
            },
        )
        logger.info("Sent SAS MAC for %s", txn_id)

    async def _handle_verification_mac(self, room, source):
        """Finish a known verification by sending "done" and forgetting its state.

        The peer's MAC is not checked here.
        """
        content = source["content"]
        txn_id = content.get("m.relates_to", {}).get("event_id", "")
        v = self._verifications.get(txn_id)
        if not v:
            return
        await self.client.room_send(
            room.room_id,
            message_type="m.key.verification.done",
            content={
                "m.relates_to": {"rel_type": "m.reference", "event_id": txn_id},
            },
        )
        logger.info("Verification complete for %s with %s", txn_id, v["sender"])
        self._verifications.pop(txn_id, None)

    # ------------------------------------------------------------------ crypto / to-device

    async def on_megolm(self, room, event: MegolmEvent):
        """Request keys for undecryptable messages."""
        logger.warning(
            "Undecryptable event %s in %s from %s — requesting keys",
            event.event_id, room.room_id, event.sender,
        )
        try:
            await self.client.request_room_key(event)
        except Exception:
            logger.debug("Key request failed", exc_info=True)

    async def on_key_verification(self, event):
        """Auto-accept to-device SAS verification via nio's built-in state machine."""
        if isinstance(event, KeyVerificationStart):
            sas = self.client.key_verifications.get(event.transaction_id)
            if sas:
                await self.client.accept_key_verification(event.transaction_id)
                await self.client.to_device(sas.share_key())
        elif isinstance(event, KeyVerificationKey):
            sas = self.client.key_verifications.get(event.transaction_id)
            if sas:
                await self.client.confirm_short_auth_string(event.transaction_id)
        elif isinstance(event, KeyVerificationMac):
            sas = self.client.key_verifications.get(event.transaction_id)
            if sas:
                mac = sas.get_mac()
                if not isinstance(mac, ToDeviceError):
                    await self.client.to_device(mac)

    async def cleanup(self):
        """Close the Matrix client and, if created, the LiveKit API client."""
        await self.client.close()
        if self.lkapi:
            await self.lkapi.aclose()
|
|
|
|
|
async def main():
    """Program entry coroutine: prepare the crypto store directory, then run the bot.

    ``Bot.cleanup`` is awaited however ``Bot.start`` finishes: by returning
    (e.g. after a failed login), by raising, or by being cancelled.
    """
    os.makedirs(STORE_PATH, exist_ok=True)
    matrix_bot = Bot()
    try:
        await matrix_bot.start()
    finally:
        await matrix_bot.cleanup()


if __name__ == "__main__":
    # Root logging at INFO so the bot's progress messages are visible.
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())
|