fix: handle Element X MSC4143 v2 encryption key format (memberships array)

Element X embeds E2EE keys inside memberships[].encryption_keys,
not at the top level of the call.member state event content.
The bot only checked content.encryption_keys, so it never found
the caller's key. The call then stalled on 'Warten auf Medien'
(waiting for media) because the encrypted audio could not be decrypted.

- Added _extract_enc_keys_from_content() helper handling both formats
- Updated on_unknown handler, VoiceSession creation, and key fetch
- Bot now publishes keys in both formats for compatibility
- Updated voice.py state fetch to check memberships[] fallback

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Christian Gick
2026-03-24 08:57:24 +02:00
parent 3363b4238f
commit c11dd73ce3
2 changed files with 81 additions and 36 deletions

103
bot.py
View File

@@ -51,6 +51,32 @@ from cross_signing import CrossSigningManager
BOT_DEVICE_ID = "AIBOT"
CALL_MEMBER_TYPE = "org.matrix.msc3401.call.member"
ENCRYPTION_KEYS_TYPE = "io.element.call.encryption_keys"
def _extract_enc_keys_from_content(content: dict, skip_device: str = "") -> list[dict]:
"""Extract encryption_keys from call.member content, handling both formats.
Format 1 (legacy): content.encryption_keys = [{key, index}, ...]
Format 2 (MSC4143 v2 / Element X): content.memberships[].encryption_keys = [{key, index}, ...]
Returns list of {key, index, device_id} dicts.
"""
results = []
# Format 1: top-level encryption_keys
top_keys = content.get("encryption_keys", content.get("keys", []))
if top_keys:
device = content.get("device_id", "")
if device != skip_device:
for k in top_keys:
results.append({"key": k.get("key", ""), "index": k.get("index", 0), "device_id": device})
# Format 2: memberships array (Element X / MSC4143 v2)
for m in content.get("memberships", []):
device = m.get("device_id", "")
if device == skip_device:
continue
for k in m.get("encryption_keys", []):
results.append({"key": k.get("key", ""), "index": k.get("index", 0), "device_id": device})
return [r for r in results if r["key"]]
MODEL_STATE_TYPE = "ai.agiliton.model"
RENAME_STATE_TYPE = "ai.agiliton.auto_rename"
@@ -1564,21 +1590,20 @@ class Bot:
# MSC4143: encryption keys may be embedded in call.member state events
# Check for keys BEFORE the early return for active calls
enc_keys = content.get("encryption_keys", content.get("keys", []))
if enc_keys and room_id in self.voice_sessions:
device_id = content.get("device_id", "")
logger.info("Found encryption_keys in call.member event from %s (device=%s, keys=%d)",
event.sender, device_id, len(enc_keys))
# Handles both top-level encryption_keys and memberships[].encryption_keys (Element X)
enc_key_entries = _extract_enc_keys_from_content(content, skip_device=BOT_DEVICE_ID)
if enc_key_entries and room_id in self.voice_sessions:
logger.info("Found %d encryption key(s) in call.member event from %s",
len(enc_key_entries), event.sender)
vs = self.voice_sessions[room_id]
import base64 as b64
for k in enc_keys:
key_b64 = k.get("key", "")
key_index = k.get("index", 0)
if key_b64:
key_b64 += "=" * (-len(key_b64) % 4)
key_bytes = b64.urlsafe_b64decode(key_b64)
vs.on_encryption_key(event.sender, device_id, key_bytes, key_index)
logger.info("Delivered call.member embedded key (index=%d) to voice session", key_index)
for entry in enc_key_entries:
key_b64 = entry["key"]
key_b64 += "=" * (-len(key_b64) % 4)
key_bytes = b64.urlsafe_b64decode(key_b64)
vs.on_encryption_key(event.sender, entry["device_id"], key_bytes, entry["index"])
logger.info("Delivered call.member embedded key (index=%d, device=%s) to voice session",
entry["index"], entry["device_id"])
if room_id in self.active_calls:
return
@@ -1621,7 +1646,15 @@ class Bot:
},
"foci_preferred": foci,
"m.call.intent": "audio",
# Publish keys in both formats for compatibility
"encryption_keys": [{"index": 0, "key": bot_key_b64}],
"memberships": [{
"device_id": BOT_DEVICE_ID,
"encryption_keys": [{"index": 0, "key": bot_key_b64}],
"expires": 7200000,
"foci_preferred": foci,
"focus_active": {"type": "livekit", "focus_selection": "oldest_membership"},
}],
}
state_key = f"_{BOT_USER}_{BOT_DEVICE_ID}_m.call"
@@ -1670,18 +1703,18 @@ class Bot:
# Check for caller's key: first in call_member event (MSC4143),
# then fall back to timeline scan (legacy io.element.call.encryption_keys)
# Handles both top-level and memberships[] format (Element X)
caller_key = None
caller_enc_keys = content.get("encryption_keys", content.get("keys", []))
if caller_enc_keys:
caller_enc_entries = _extract_enc_keys_from_content(content, skip_device=BOT_DEVICE_ID)
if caller_enc_entries:
import base64 as b64
for k in caller_enc_keys:
key_b64 = k.get("key", "")
key_index = k.get("index", 0)
if key_b64:
key_b64 += "=" * (-len(key_b64) % 4)
caller_key = b64.urlsafe_b64decode(key_b64)
vs.on_encryption_key(event.sender, caller_device_id, caller_key, key_index)
logger.info("Got caller E2EE key from call.member event (index=%d)", key_index)
for entry in caller_enc_entries:
key_b64 = entry["key"]
key_b64 += "=" * (-len(key_b64) % 4)
caller_key = b64.urlsafe_b64decode(key_b64)
vs.on_encryption_key(event.sender, entry["device_id"], caller_key, entry["index"])
logger.info("Got caller E2EE key from call.member event (index=%d, device=%s)",
entry["index"], entry["device_id"])
if not caller_key:
caller_key = await self._get_call_encryption_key(room_id, event.sender, caller_device_id)
@@ -3537,6 +3570,7 @@ class Bot:
token = self.client.access_token
# MSC4143: Check call.member state events for embedded encryption_keys
# Handles both top-level and memberships[] format (Element X)
state_url = f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/state"
async with httpx.AsyncClient(timeout=10.0) as http:
resp = await http.get(state_url, headers={"Authorization": f"Bearer {token}"})
@@ -3549,18 +3583,17 @@ class Bot:
if evt_sender == BOT_USER:
continue
content = evt.get("content", {})
enc_keys = content.get("encryption_keys", [])
if enc_keys:
device = content.get("device_id", "")
logger.info("Found encryption_keys in call.member state from %s (device=%s, keys=%d)",
evt_sender, device, len(enc_keys))
for k in enc_keys:
key_b64 = k.get("key", "")
if key_b64:
key_b64 += "=" * (-len(key_b64) % 4)
key = base64.urlsafe_b64decode(key_b64)
logger.info("Got E2EE key from call.member state (%s, %d bytes)", evt_sender, len(key))
return key
entries = _extract_enc_keys_from_content(content, skip_device=BOT_DEVICE_ID)
if entries:
logger.info("Found %d encryption key(s) in call.member state from %s",
len(entries), evt_sender)
entry = entries[0]
key_b64 = entry["key"]
key_b64 += "=" * (-len(key_b64) % 4)
key = base64.urlsafe_b64decode(key_b64)
logger.info("Got E2EE key from call.member state (%s, device=%s, %d bytes)",
evt_sender, entry["device_id"], len(key))
return key
# Legacy: scan timeline for io.element.call.encryption_keys events
async with httpx.AsyncClient(timeout=10.0) as http:

View File

@@ -554,6 +554,7 @@ class VoiceSession:
user_id = self.nio_client.user_id
# MSC4143: check call.member state events first
# Handles both top-level encryption_keys and memberships[].encryption_keys (Element X)
try:
state_url = f"{homeserver}/_matrix/client/v3/rooms/{self.room_id}/state"
async with httpx.AsyncClient(timeout=10.0) as http:
@@ -566,7 +567,17 @@ class VoiceSession:
if sender == user_id:
continue
content = evt.get("content", {})
# Extract keys from both formats
enc_keys = content.get("encryption_keys", [])
# MSC4143 v2 / Element X: keys inside memberships array
if not enc_keys:
for m in content.get("memberships", []):
m_keys = m.get("encryption_keys", [])
if m_keys:
enc_keys = m_keys
# Use device_id from membership entry
content = {**content, "device_id": m.get("device_id", content.get("device_id", ""))}
break
if enc_keys:
device = content.get("device_id", "")
import base64 as b64
@@ -581,7 +592,8 @@ class VoiceSession:
self.on_encryption_key(sender, device, key_bytes, key_index)
max_idx = max(self._caller_all_keys.keys()) if self._caller_all_keys else key_index
latest = self._caller_all_keys.get(max_idx, key_bytes)
logger.info("Got key from call.member state (sender=%s, index=%d)", sender, key_index)
logger.info("Got key from call.member state (sender=%s, device=%s, index=%d)",
sender, device, key_index)
return latest
except Exception as e:
logger.debug("call.member state key fetch failed: %s", e)