fix(e2ee): fix screen share key rotation failures (MAT-164)

- Route HTTP-fetched keys through on_encryption_key() for proper rotation detection
- Replace boolean refetch gate with 500ms timestamp throttle for faster recovery
- Reduce DEC_FAILED cooldown from 2s to 0.5s
- Extend proactive key poll from 3s to 10s window
- Add continuous background key poller (3s interval) during active calls

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Christian Gick
2026-03-19 12:35:50 +02:00
parent 183c41a72b
commit c29c2170f3

View File

@@ -488,6 +488,7 @@ class VoiceSession:
self._document_context = document_context # PDF text from room for voice context self._document_context = document_context # PDF text from room for voice context
self._transcript: list[dict] = [] # {"role": "user"|"assistant", "text": "..."} self._transcript: list[dict] = [] # {"role": "user"|"assistant", "text": "..."}
self._video_track: rtc.Track | None = None # remote video track for on-demand vision self._video_track: rtc.Track | None = None # remote video track for on-demand vision
self._key_poller_task: asyncio.Task | None = None # continuous background key poller (MAT-164)
def on_encryption_key(self, sender, device_id, key, index): def on_encryption_key(self, sender, device_id, key, index):
"""Receive E2EE key from Element Call participant. """Receive E2EE key from Element Call participant.
@@ -590,17 +591,51 @@ class VoiceSession:
if all_keys: if all_keys:
if device: if device:
self._caller_identity = f"{sender}:{device}" self._caller_identity = f"{sender}:{device}"
self._caller_all_keys.update(all_keys) # Route each key through on_encryption_key() so
max_idx = max(all_keys.keys()) # same-index rotation detection works correctly
# (MAT-164: was bypassing rotation via direct update).
for key_index, key_bytes in sorted(all_keys.items()):
self.on_encryption_key(sender, device, key_bytes, key_index)
max_idx = max(self._caller_all_keys.keys()) if self._caller_all_keys else max(all_keys.keys())
latest_key = self._caller_all_keys.get(max_idx, all_keys[max(all_keys.keys())])
logger.info("Loaded caller keys at indices %s (using %d, key=%s)", logger.info("Loaded caller keys at indices %s (using %d, key=%s)",
sorted(all_keys.keys()), max_idx, sorted(self._caller_all_keys.keys()), max_idx,
all_keys[max_idx].hex()[:8]) latest_key.hex()[:8])
return all_keys[max_idx] return latest_key
logger.info("No encryption_keys events in last %d timeline events", len(events)) logger.info("No encryption_keys events in last %d timeline events", len(events))
except Exception as e: except Exception as e:
logger.warning("HTTP encryption key fetch failed: %s", e) logger.warning("HTTP encryption key fetch failed: %s", e)
return None return None
async def _continuous_key_poller(self):
    """Poll for E2EE key rotations every 3s while connected (MAT-164).

    Runs for as long as ``self.lk_room`` is truthy. Each cycle sleeps,
    re-fetches the caller key with ``_fetch_encryption_key_http()`` and,
    when the fetched key differs from ``self._caller_key``, passes it to
    ``on_encryption_key()`` so rotation handling stays in one place.

    Errors raised by a single poll are logged at debug level and do not
    stop the loop. Cancelling the task ends the loop quietly.
    """
    poll_interval = 3.0  # seconds between timeline fetches

    logger.info("Background key poller started")
    try:
        while self.lk_room:
            await asyncio.sleep(poll_interval)
            # The room may have been torn down while we were sleeping.
            if not self.lk_room:
                break
            try:
                prev_fingerprint = self._caller_key.hex() if self._caller_key else None
                new_key = await self._fetch_encryption_key_http()
                # new_key is truthy here, so this also covers "no caller key yet".
                if new_key and new_key != self._caller_key:
                    logger.info("Background poll: key rotated (%s -> %s)",
                                prev_fingerprint[:8] if prev_fingerprint else "none",
                                new_key.hex()[:8])
                    # The identity is stored as "<sender>:<device>", and the
                    # sender may itself contain ':' (e.g. "@user:server"),
                    # so split on the LAST colon instead of the first.
                    identity = self._caller_identity or ""
                    sender, sep, device = identity.rpartition(":")
                    if not sep:
                        # No separator: keep the previous behaviour of
                        # passing the whole string for both fields.
                        sender = device = identity
                    # Route through on_encryption_key for proper rotation handling
                    self.on_encryption_key(sender, device, new_key, 0)
            except Exception as exc:
                # Best-effort poll: a failed cycle must not kill the poller.
                logger.debug("Background key poll error: %s", exc)
    except asyncio.CancelledError:
        # Cancelled on session shutdown; fall through to the stop log.
        pass
    logger.info("Background key poller stopped")
async def start(self): async def start(self):
self._task = asyncio.create_task(self._run()) self._task = asyncio.create_task(self._run())
@@ -616,6 +651,8 @@ class VoiceSession:
await obj.close() await obj.close()
except Exception: except Exception:
pass pass
if self._key_poller_task and not self._key_poller_task.done():
self._key_poller_task.cancel()
if self._activity_video: if self._activity_video:
self._activity_video.stop() self._activity_video.stop()
if self._activity_task and not self._activity_task.done(): if self._activity_task and not self._activity_task.done():
@@ -714,7 +751,7 @@ class VoiceSession:
# for DEC_FAILED (MAT-164). # for DEC_FAILED (MAT-164).
async def _proactive_key_poll(pid=p.identity): async def _proactive_key_poll(pid=p.identity):
pre_key = self._caller_key pre_key = self._caller_key
for attempt in range(6): # 6 × 500ms = 3s for attempt in range(20): # 20 × 500ms = 10s (MAT-164)
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
if self._caller_key != pre_key: if self._caller_key != pre_key:
logger.info("Proactive poll: key rotated via sync (attempt %d)", attempt + 1) logger.info("Proactive poll: key rotated via sync (attempt %d)", attempt + 1)
@@ -728,7 +765,7 @@ class VoiceSession:
self._caller_identity.split(":")[-1] if self._caller_identity else "", self._caller_identity.split(":")[-1] if self._caller_identity else "",
new_key, 0) new_key, 0)
return return
logger.info("Proactive poll: no key rotation after 3s") logger.info("Proactive poll: no key rotation after 10s")
asyncio.ensure_future(_proactive_key_poll()) asyncio.ensure_future(_proactive_key_poll())
if int(t.kind) in (1, 2) and e2ee_opts is not None: # audio + video tracks if int(t.kind) in (1, 2) and e2ee_opts is not None: # audio + video tracks
caller_id = p.identity caller_id = p.identity
@@ -761,10 +798,10 @@ class VoiceSession:
_e2ee_state_names = {0:"NEW",1:"OK",2:"ENC_FAILED",3:"DEC_FAILED",4:"MISSING_KEY",5:"RATCHETED",6:"INTERNAL_ERR"} _e2ee_state_names = {0:"NEW",1:"OK",2:"ENC_FAILED",3:"DEC_FAILED",4:"MISSING_KEY",5:"RATCHETED",6:"INTERNAL_ERR"}
_last_rekey_time = {} # per-participant cooldown for DEC_FAILED re-keying _last_rekey_time = {} # per-participant cooldown for DEC_FAILED re-keying
_dec_failed_count = {} # consecutive DEC_FAILED per participant _dec_failed_count = {} # consecutive DEC_FAILED per participant
_refetch_in_progress = False _last_refetch_time = 0.0 # timestamp throttle replaces boolean gate (MAT-164)
@self.lk_room.on("e2ee_state_changed") @self.lk_room.on("e2ee_state_changed")
def on_e2ee_state(participant, state): def on_e2ee_state(participant, state):
nonlocal _refetch_in_progress nonlocal _last_refetch_time
state_name = _e2ee_state_names.get(int(state), f"UNKNOWN_{state}") state_name = _e2ee_state_names.get(int(state), f"UNKNOWN_{state}")
p_id = participant.identity if participant else "local" p_id = participant.identity if participant else "local"
logger.info("E2EE_STATE: participant=%s state=%s", p_id, state_name) logger.info("E2EE_STATE: participant=%s state=%s", p_id, state_name)
@@ -776,11 +813,11 @@ class VoiceSession:
if int(state) == 3: if int(state) == 3:
_dec_failed_count[p_id] = _dec_failed_count.get(p_id, 0) + 1 _dec_failed_count[p_id] = _dec_failed_count.get(p_id, 0) + 1
# After 1+ DEC_FAILED: re-fetch key from timeline (key may have rotated) # After 1+ DEC_FAILED: re-fetch key from timeline (key may have rotated)
if _dec_failed_count[p_id] >= 1 and not _refetch_in_progress: # Timestamp throttle: allow re-fetch if >500ms since last (MAT-164)
_refetch_in_progress = True if _dec_failed_count[p_id] >= 1 and (now - _last_refetch_time) > 0.5:
_last_refetch_time = now
_p_id_copy = p_id # capture for closure _p_id_copy = p_id # capture for closure
async def _refetch_key(): async def _refetch_key():
nonlocal _refetch_in_progress
try: try:
logger.info("DEC_FAILED x%d — re-fetching key from timeline", logger.info("DEC_FAILED x%d — re-fetching key from timeline",
_dec_failed_count.get(_p_id_copy, 0)) _dec_failed_count.get(_p_id_copy, 0))
@@ -799,12 +836,10 @@ class VoiceSession:
logger.info("Re-fetch returned no fresh key") logger.info("Re-fetch returned no fresh key")
except Exception as exc: except Exception as exc:
logger.warning("Key re-fetch failed: %s", exc) logger.warning("Key re-fetch failed: %s", exc)
finally:
_refetch_in_progress = False
asyncio.ensure_future(_refetch_key()) asyncio.ensure_future(_refetch_key())
# Cooldown: only re-key every 2s to avoid tight loops # Cooldown: only re-key every 0.5s for fast recovery (MAT-164)
last = _last_rekey_time.get(p_id, 0) last = _last_rekey_time.get(p_id, 0)
if (now - last) < 2.0: if (now - last) < 0.5:
return return
_last_rekey_time[p_id] = now _last_rekey_time[p_id] = now
if self._caller_all_keys: if self._caller_all_keys:
@@ -821,6 +856,9 @@ class VoiceSession:
logger.info("Connected (E2EE=HKDF), remote=%d", logger.info("Connected (E2EE=HKDF), remote=%d",
len(self.lk_room.remote_participants)) len(self.lk_room.remote_participants))
# Start continuous background key poller (MAT-164)
self._key_poller_task = asyncio.create_task(self._continuous_key_poller())
# Set bot's own key immediately after connect — local frame cryptor exists at connect time. # Set bot's own key immediately after connect — local frame cryptor exists at connect time.
# Pre-derive via HKDF in Python since KDF_RAW is set (no Rust-side derivation). # Pre-derive via HKDF in Python since KDF_RAW is set (no Rust-side derivation).
kp = self.lk_room.e2ee_manager.key_provider kp = self.lk_room.e2ee_manager.key_provider