diff --git a/bot.py b/bot.py index 823d75a..d61237d 100644 --- a/bot.py +++ b/bot.py @@ -51,6 +51,32 @@ from cross_signing import CrossSigningManager BOT_DEVICE_ID = "AIBOT" CALL_MEMBER_TYPE = "org.matrix.msc3401.call.member" ENCRYPTION_KEYS_TYPE = "io.element.call.encryption_keys" + + +def _extract_enc_keys_from_content(content: dict, skip_device: str = "") -> list[dict]: + """Extract encryption_keys from call.member content, handling both formats. + + Format 1 (legacy): content.encryption_keys = [{key, index}, ...] + Format 2 (MSC4143 v2 / Element X): content.memberships[].encryption_keys = [{key, index}, ...] + + Returns list of {key, index, device_id} dicts. + """ + results = [] + # Format 1: top-level encryption_keys + top_keys = content.get("encryption_keys", content.get("keys", [])) + if top_keys: + device = content.get("device_id", "") + if device != skip_device: + for k in top_keys: + results.append({"key": k.get("key", ""), "index": k.get("index", 0), "device_id": device}) + # Format 2: memberships array (Element X / MSC4143 v2) + for m in content.get("memberships", []): + device = m.get("device_id", "") + if device == skip_device: + continue + for k in m.get("encryption_keys", []): + results.append({"key": k.get("key", ""), "index": k.get("index", 0), "device_id": device}) + return [r for r in results if r["key"]] MODEL_STATE_TYPE = "ai.agiliton.model" RENAME_STATE_TYPE = "ai.agiliton.auto_rename" @@ -1564,21 +1590,20 @@ class Bot: # MSC4143: encryption keys may be embedded in call.member state events # Check for keys BEFORE the early return for active calls - enc_keys = content.get("encryption_keys", content.get("keys", [])) - if enc_keys and room_id in self.voice_sessions: - device_id = content.get("device_id", "") - logger.info("Found encryption_keys in call.member event from %s (device=%s, keys=%d)", - event.sender, device_id, len(enc_keys)) + # Handles both top-level encryption_keys and memberships[].encryption_keys (Element X) + enc_key_entries = _extract_enc_keys_from_content(content, skip_device=BOT_DEVICE_ID) + if enc_key_entries and room_id in self.voice_sessions: + logger.info("Found %d encryption key(s) in call.member event from %s", + len(enc_key_entries), event.sender) vs = self.voice_sessions[room_id] import base64 as b64 - for k in enc_keys: - key_b64 = k.get("key", "") - key_index = k.get("index", 0) - if key_b64: - key_b64 += "=" * (-len(key_b64) % 4) - key_bytes = b64.urlsafe_b64decode(key_b64) - vs.on_encryption_key(event.sender, device_id, key_bytes, key_index) - logger.info("Delivered call.member embedded key (index=%d) to voice session", key_index) + for entry in enc_key_entries: + key_b64 = entry["key"] + key_b64 += "=" * (-len(key_b64) % 4) + key_bytes = b64.urlsafe_b64decode(key_b64) + vs.on_encryption_key(event.sender, entry["device_id"], key_bytes, entry["index"]) + logger.info("Delivered call.member embedded key (index=%d, device=%s) to voice session", + entry["index"], entry["device_id"]) if room_id in self.active_calls: return @@ -1621,7 +1646,15 @@ class Bot: }, "foci_preferred": foci, "m.call.intent": "audio", + # Publish keys in both formats for compatibility "encryption_keys": [{"index": 0, "key": bot_key_b64}], + "memberships": [{ + "device_id": BOT_DEVICE_ID, + "encryption_keys": [{"index": 0, "key": bot_key_b64}], + "expires": 7200000, + "foci_preferred": foci, + "focus_active": {"type": "livekit", "focus_selection": "oldest_membership"}, + }], } state_key = f"_{BOT_USER}_{BOT_DEVICE_ID}_m.call" @@ -1670,18 +1703,18 @@ class Bot: # Check for caller's key: first in call_member event (MSC4143), # then fall back to timeline scan (legacy io.element.call.encryption_keys) + # Handles both top-level and memberships[] format (Element X) caller_key = None - caller_enc_keys = content.get("encryption_keys", content.get("keys", [])) - if caller_enc_keys: + caller_enc_entries = _extract_enc_keys_from_content(content, skip_device=BOT_DEVICE_ID) + if caller_enc_entries: import base64 as b64 - for k in caller_enc_keys: - key_b64 = k.get("key", "") - key_index = k.get("index", 0) - if key_b64: - key_b64 += "=" * (-len(key_b64) % 4) - caller_key = b64.urlsafe_b64decode(key_b64) - vs.on_encryption_key(event.sender, caller_device_id, caller_key, key_index) - logger.info("Got caller E2EE key from call.member event (index=%d)", key_index) + for entry in caller_enc_entries: + key_b64 = entry["key"] + key_b64 += "=" * (-len(key_b64) % 4) + caller_key = b64.urlsafe_b64decode(key_b64) + vs.on_encryption_key(event.sender, entry["device_id"], caller_key, entry["index"]) + logger.info("Got caller E2EE key from call.member event (index=%d, device=%s)", + entry["index"], entry["device_id"]) if not caller_key: caller_key = await self._get_call_encryption_key(room_id, event.sender, caller_device_id) @@ -3537,6 +3570,7 @@ class Bot: token = self.client.access_token # MSC4143: Check call.member state events for embedded encryption_keys + # Handles both top-level and memberships[] format (Element X) state_url = f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/state" async with httpx.AsyncClient(timeout=10.0) as http: resp = await http.get(state_url, headers={"Authorization": f"Bearer {token}"}) @@ -3549,18 +3583,17 @@ class Bot: if evt_sender == BOT_USER: continue content = evt.get("content", {}) - enc_keys = content.get("encryption_keys", []) - if enc_keys: - device = content.get("device_id", "") - logger.info("Found encryption_keys in call.member state from %s (device=%s, keys=%d)", - evt_sender, device, len(enc_keys)) - for k in enc_keys: - key_b64 = k.get("key", "") - if key_b64: - key_b64 += "=" * (-len(key_b64) % 4) - key = base64.urlsafe_b64decode(key_b64) - logger.info("Got E2EE key from call.member state (%s, %d bytes)", evt_sender, len(key)) - return key + entries = _extract_enc_keys_from_content(content, skip_device=BOT_DEVICE_ID) + if entries: + logger.info("Found %d encryption key(s) in call.member state from %s", + len(entries), evt_sender) + entry = entries[0] + key_b64 = entry["key"] + key_b64 += "=" * (-len(key_b64) % 4) + key = base64.urlsafe_b64decode(key_b64) + logger.info("Got E2EE key from call.member state (%s, device=%s, %d bytes)", + evt_sender, entry["device_id"], len(key)) + return key # Legacy: scan timeline for io.element.call.encryption_keys events async with httpx.AsyncClient(timeout=10.0) as http: diff --git a/voice.py b/voice.py index 2050762..8b38fd2 100644 --- a/voice.py +++ b/voice.py @@ -554,6 +554,7 @@ class VoiceSession: user_id = self.nio_client.user_id # MSC4143: check call.member state events first + # Handles both top-level encryption_keys and memberships[].encryption_keys (Element X) try: state_url = f"{homeserver}/_matrix/client/v3/rooms/{self.room_id}/state" async with httpx.AsyncClient(timeout=10.0) as http: @@ -566,7 +567,17 @@ class VoiceSession: if sender == user_id: continue content = evt.get("content", {}) + # Extract keys from both formats enc_keys = content.get("encryption_keys", []) + # MSC4143 v2 / Element X: keys inside memberships array + if not enc_keys: + for m in content.get("memberships", []): + m_keys = m.get("encryption_keys", []) + if m_keys: + enc_keys = m_keys + # Use device_id from membership entry + content = {**content, "device_id": m.get("device_id", content.get("device_id", ""))} + break if enc_keys: device = content.get("device_id", "") import base64 as b64 @@ -581,7 +592,8 @@ class VoiceSession: self.on_encryption_key(sender, device, key_bytes, key_index) max_idx = max(self._caller_all_keys.keys()) if self._caller_all_keys else key_index latest = self._caller_all_keys.get(max_idx, key_bytes) - logger.info("Got key from call.member state (sender=%s, index=%d)", sender, key_index) + logger.info("Got key from call.member state (sender=%s, device=%s, index=%d)", + sender, device, key_index) return latest except Exception as e: logger.debug("call.member state key fetch failed: %s", e)