Rust FFI's KDF_HKDF path for incoming decryption may use wrong parameters. Pre-derive HKDF(base_key, salt="LKFrameEncryptionKey", info=identity) in Python and pass derived key with KDF_NONE so Rust FFI uses it directly as frame key. Matches EC's MatrixKeyProvider: ratchetWindowSize=10, keyringSize=256. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
405 lines
19 KiB
Python
405 lines
19 KiB
Python
"""Voice session: LiveKit + STT/LLM/TTS pipeline.
|
||
E2EE via HKDF key derivation (Element Call compatible).
|
||
Requires patched livekit-rtc FFI binary from onestacked/livekit-rust-sdks."""
|
||
import asyncio
|
||
import base64
|
||
import datetime
|
||
import hashlib
|
||
import logging
|
||
import os
|
||
|
||
import aiohttp
|
||
from livekit import rtc, api as lkapi
|
||
from livekit.agents import Agent, AgentSession, room_io
|
||
from livekit.plugins import openai as lk_openai, elevenlabs, silero
|
||
|
||
logger = logging.getLogger("matrix-ai-voice")

# Verbose logs from the agents pipeline and its plugins, to diagnose audio issues.
for _lk_logger_name in ("livekit.agents", "livekit.plugins"):
    logging.getLogger(_lk_logger_name).setLevel(logging.DEBUG)
del _lk_logger_name

# Service endpoints and credentials, read once at import time.
LITELLM_URL = os.getenv("LITELLM_BASE_URL", "")
LITELLM_KEY = os.getenv("LITELLM_API_KEY", "not-needed")
LK_API_KEY = os.getenv("LIVEKIT_API_KEY", "")
LK_API_SECRET = os.getenv("LIVEKIT_API_SECRET", "")
ELEVENLABS_KEY = os.getenv("ELEVENLABS_API_KEY", "")

# ElevenLabs voice used when ELEVENLABS_VOICE_ID is not set.
DEFAULT_VOICE_ID = "onwK4e9ZLuTAKqWW03F9"  # Daniel - male, free tier

# System prompt for the voice agent (German). In English: always answer in
# German, at most 1-2 short sentences, be direct, invent nothing, answer only
# what is asked, and if nobody asks anything just briefly say hello.
VOICE_PROMPT = """Du bist ein hilfreicher Sprachassistent in einem Matrix-Anruf.

STRIKTE Regeln:
- Antworte IMMER auf Deutsch
- Halte JEDE Antwort auf MAXIMAL 1-2 kurze Saetze
- Sei direkt und praezise, keine Fuellwoerter
- Erfinde NICHTS - keine Geschichten, keine Musik, keine Fantasie
- Beantworte nur was gefragt wird
- Wenn niemand etwas fragt, sage nur kurz Hallo"""
|
||
|
||
# Lazily loaded Silero VAD, shared by every caller of _get_vad() in this process.
_vad = None


def _get_vad():
    """Return the shared Silero VAD instance, loading it on the first call."""
    global _vad
    if _vad is not None:
        return _vad
    loaded = silero.VAD.load()
    _vad = loaded
    return loaded
|
||
|
||
def _make_lk_identity(user_id, device_id):
|
||
return f"{user_id}:{device_id}"
|
||
|
||
def _compute_lk_room_name(room_id):
|
||
raw = f"{room_id}|m.call#ROOM"
|
||
return base64.b64encode(hashlib.sha256(raw.encode()).digest()).decode().rstrip("=")
|
||
|
||
def _generate_lk_jwt(room_id, user_id, device_id):
    """Mint a LiveKit access token for the bot device, valid for 24 hours.

    The token's identity is ``<user_id>:<device_id>`` and its display name is
    the Matrix user ID. It grants join/publish/subscribe on the LiveKit room
    derived from *room_id*. Returns the encoded JWT string.
    """
    lk_room = _compute_lk_room_name(room_id)
    identity = _make_lk_identity(user_id, device_id)
    grants = lkapi.VideoGrants(
        room_join=True,
        room=lk_room,
        can_publish=True,
        can_subscribe=True,
    )
    builder = lkapi.AccessToken(LK_API_KEY, LK_API_SECRET)
    builder = builder.with_identity(identity).with_name(user_id)
    builder = builder.with_grants(grants).with_ttl(datetime.timedelta(hours=24))
    logger.info("JWT: identity=%s room=%s", identity, lk_room)
    return builder.to_jwt()
|
||
|
||
|
||
KDF_HKDF = 1
|
||
KDF_NONE = 0
|
||
|
||
_RATCHET_SALT = b"LKFrameEncryptionKey"
|
||
|
||
|
||
def _hkdf(ikm: bytes, salt: bytes, info: bytes, length: int = 32) -> bytes:
|
||
"""HKDF-SHA256 (RFC 5869). Pre-derives frame key to bypass Rust FFI's HKDF."""
|
||
import hmac as _hmac, hashlib as _hashlib
|
||
prk = _hmac.new(salt, ikm, _hashlib.sha256).digest()
|
||
okm, t = b"", b""
|
||
for i in range(1, (length + 31) // 32 + 1):
|
||
t = _hmac.new(prk, t + info + bytes([i]), _hashlib.sha256).digest()
|
||
okm += t
|
||
return okm[:length]
|
||
|
||
|
||
def _build_e2ee_options() -> rtc.E2EEOptions:
    """Build LiveKit E2EE options: per-participant keys, FFI key derivation disabled.

    Frame keys are derived in Python from the raw keys Element Call distributes
    over Matrix. They are handed to the FFI with ``KDF_NONE``, so the FFI must
    use the supplied bytes directly as frame keys. The Rust FFI's KDF_HKDF path
    was suspected of using wrong parameters for incoming decryption, which is
    why derivation happens here.

    Ratchet window 10 and key-ring size 256 follow Element Call's
    MatrixKeyProvider. The ratchet salt is the module-wide ``_RATCHET_SALT``.
    """
    key_opts = rtc.KeyProviderOptions(
        # Empty shared key = per-participant mode; keys are installed later via
        # key_provider.set_key(identity, key, index).
        shared_key=b"",
        ratchet_window_size=10,
        ratchet_salt=_RATCHET_SALT,
        # NOTE(review): -1 presumably means "never stop trying to decrypt"; confirm.
        failure_tolerance=-1,
        key_ring_size=256,
        key_derivation_function=KDF_NONE,  # keys are pre-derived in Python
    )
    return rtc.E2EEOptions(key_provider_options=key_opts)
|
||
|
||
|
||
class VoiceSession:
|
||
def __init__(self, nio_client, room_id, device_id, lk_url, model="claude-sonnet",
|
||
publish_key_cb=None, bot_key: bytes | None = None):
|
||
self.nio_client = nio_client
|
||
self.room_id = room_id
|
||
self.device_id = device_id
|
||
self.lk_url = lk_url
|
||
self.model = model
|
||
self.lk_room = None
|
||
self.session = None
|
||
self._task = None
|
||
self._http_session = None
|
||
self._caller_key: bytes | None = None
|
||
self._caller_identity: str | None = None
|
||
self._caller_all_keys: dict = {} # {index: bytes} — all caller keys by index
|
||
self._bot_key: bytes = bot_key or os.urandom(16)
|
||
self._publish_key_cb = publish_key_cb
|
||
|
||
def on_encryption_key(self, sender, device_id, key, index):
|
||
"""Receive E2EE key from Element Call participant."""
|
||
if not key:
|
||
return
|
||
if not self._caller_key:
|
||
self._caller_key = key
|
||
self._caller_identity = f"{sender}:{device_id}"
|
||
self._caller_all_keys[index] = key
|
||
logger.info("E2EE key received from %s:%s (index=%d, %d bytes)",
|
||
sender, device_id, index, len(key))
|
||
# Live-update per-participant key on rotation — pre-derive HKDF matching KDF_NONE mode.
|
||
if self.lk_room and hasattr(self.lk_room, 'e2ee_manager'):
|
||
try:
|
||
kp = self.lk_room.e2ee_manager.key_provider
|
||
caller_id = self._caller_identity or f"{sender}:{device_id}"
|
||
derived = _hkdf(key, _RATCHET_SALT, caller_id.encode())
|
||
kp.set_key(caller_id, derived, index)
|
||
logger.info("Live-updated caller frame key[%d] for %s (%d→%d bytes)",
|
||
index, caller_id, len(key), len(derived))
|
||
except Exception as e:
|
||
logger.warning("Failed to live-update caller key: %s", e)
|
||
|
||
async def _fetch_encryption_key_http(self) -> bytes | None:
|
||
"""Fetch encryption key from room timeline (NOT state) via Matrix HTTP API.
|
||
|
||
Element Call distributes encryption keys as timeline events, not state.
|
||
"""
|
||
import httpx
|
||
homeserver = str(self.nio_client.homeserver)
|
||
token = self.nio_client.access_token
|
||
url = f"{homeserver}/_matrix/client/v3/rooms/{self.room_id}/messages"
|
||
try:
|
||
async with httpx.AsyncClient(timeout=10.0) as http:
|
||
resp = await http.get(
|
||
url,
|
||
headers={"Authorization": f"Bearer {token}"},
|
||
params={"dir": "b", "limit": "50"},
|
||
)
|
||
resp.raise_for_status()
|
||
data = resp.json()
|
||
events = data.get("chunk", [])
|
||
user_id = self.nio_client.user_id
|
||
for evt in events:
|
||
evt_type = evt.get("type", "")
|
||
if evt_type == "io.element.call.encryption_keys":
|
||
sender = evt.get("sender", "")
|
||
if sender == user_id:
|
||
continue # skip our own key
|
||
content = evt.get("content", {})
|
||
device = content.get("device_id", "")
|
||
logger.info("Found encryption_keys timeline event: sender=%s device=%s",
|
||
sender, device)
|
||
all_keys = {}
|
||
import base64 as b64
|
||
for k in content.get("keys", []):
|
||
key_b64 = k.get("key", "")
|
||
key_index = k.get("index", 0)
|
||
if key_b64:
|
||
key_b64 += "=" * (-len(key_b64) % 4)
|
||
key_bytes = b64.urlsafe_b64decode(key_b64)
|
||
all_keys[key_index] = key_bytes
|
||
if all_keys:
|
||
if device:
|
||
self._caller_identity = f"{sender}:{device}"
|
||
self._caller_all_keys.update(all_keys)
|
||
max_idx = max(all_keys.keys())
|
||
logger.info("Loaded caller keys at indices %s (using %d)",
|
||
sorted(all_keys.keys()), max_idx)
|
||
return all_keys[max_idx]
|
||
logger.info("No encryption_keys events in last %d timeline events", len(events))
|
||
except Exception as e:
|
||
logger.warning("HTTP encryption key fetch failed: %s", e)
|
||
return None
|
||
|
||
async def start(self):
|
||
self._task = asyncio.create_task(self._run())
|
||
|
||
async def stop(self):
|
||
for obj in [self.session, self.lk_room, self._http_session]:
|
||
if obj:
|
||
try:
|
||
if hasattr(obj, "aclose"):
|
||
await obj.aclose()
|
||
elif hasattr(obj, "disconnect"):
|
||
await obj.disconnect()
|
||
elif hasattr(obj, "close"):
|
||
await obj.close()
|
||
except Exception:
|
||
pass
|
||
if self._task and not self._task.done():
|
||
self._task.cancel()
|
||
try:
|
||
await self._task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
async def _run(self):
|
||
try:
|
||
user_id = self.nio_client.user_id
|
||
bot_identity = _make_lk_identity(user_id, self.device_id)
|
||
jwt = _generate_lk_jwt(self.room_id, user_id, self.device_id)
|
||
|
||
# Publish bot's own key immediately so Element Call can decrypt us
|
||
if self._publish_key_cb:
|
||
self._publish_key_cb(self._bot_key)
|
||
logger.info("Published bot E2EE key (%d bytes)", len(self._bot_key))
|
||
|
||
# Check timeline for caller's encryption key
|
||
caller_key = await self._fetch_encryption_key_http()
|
||
if caller_key:
|
||
self._caller_key = caller_key
|
||
logger.info("Got caller E2EE key via timeline (%d bytes)", len(caller_key))
|
||
|
||
if not self._caller_key:
|
||
# Wait up to 15s for key via sync handler
|
||
logger.info("No key in timeline yet, waiting for sync...")
|
||
for _ in range(150):
|
||
if self._caller_key:
|
||
break
|
||
await asyncio.sleep(0.1)
|
||
|
||
# Connect in per-participant mode (empty shared_key) so Rust FFI uses
|
||
# identity-based HKDF — matching Element Call's JS SFrame key derivation.
|
||
# Keys are set post-connect via set_key(identity, key, index).
|
||
e2ee_opts = _build_e2ee_options()
|
||
room_opts = rtc.RoomOptions(e2ee=e2ee_opts)
|
||
self.lk_room = rtc.Room()
|
||
|
||
@self.lk_room.on("participant_connected")
|
||
def on_p(p):
|
||
logger.info("Participant connected: %s", p.identity)
|
||
|
||
@self.lk_room.on("track_published")
|
||
def on_tp(pub, p):
|
||
logger.info("Track pub: %s %s kind=%s", p.identity, pub.sid, pub.kind)
|
||
|
||
@self.lk_room.on("track_subscribed")
|
||
def on_ts(t, pub, p):
|
||
logger.info("Track sub: %s %s kind=%s", p.identity, pub.sid, t.kind)
|
||
|
||
_e2ee_state_names = {0:"NEW",1:"OK",2:"ENC_FAILED",3:"DEC_FAILED",4:"MISSING_KEY",5:"RATCHETED",6:"INTERNAL_ERR"}
|
||
@self.lk_room.on("e2ee_state_changed")
|
||
def on_e2ee_state(participant, state):
|
||
state_name = _e2ee_state_names.get(int(state), f"UNKNOWN_{state}")
|
||
p_id = participant.identity if participant else "local"
|
||
logger.info("E2EE_STATE: participant=%s state=%s", p_id, state_name)
|
||
|
||
await self.lk_room.connect(self.lk_url, jwt, options=room_opts)
|
||
logger.info("Connected (E2EE=HKDF), remote=%d",
|
||
len(self.lk_room.remote_participants))
|
||
|
||
# Element Call rotates its encryption key when bot joins the LiveKit room.
|
||
# EC sends the new key via Matrix (Megolm-encrypted); nio sync will decrypt it
|
||
# and call on_encryption_key(), which updates self._caller_all_keys.
|
||
# NOTE: HTTP fetch is useless here — keys are Matrix-E2EE encrypted (m.room.encrypted).
|
||
pre_max_idx = max(self._caller_all_keys.keys()) if self._caller_all_keys else -1
|
||
logger.info("Waiting for EC key rotation via nio sync (current max_idx=%d)...", pre_max_idx)
|
||
for _attempt in range(20): # up to 10s (20 × 0.5s)
|
||
await asyncio.sleep(0.5)
|
||
new_max = max(self._caller_all_keys.keys()) if self._caller_all_keys else -1
|
||
if new_max > pre_max_idx:
|
||
self._caller_key = self._caller_all_keys[new_max]
|
||
logger.info("Key rotated: index %d→%d (%d bytes)",
|
||
pre_max_idx, new_max, len(self._caller_key))
|
||
break
|
||
if _attempt % 4 == 3: # log every 2s
|
||
logger.info("Key rotation wait %ds: max_idx still %d", (_attempt + 1) // 2, new_max)
|
||
else:
|
||
logger.warning("No key rotation after 10s — using pre-join key[%d]", pre_max_idx)
|
||
|
||
# Set per-participant keys via key provider.
|
||
# We pre-derive HKDF(base_key, salt=ratchetSalt, info=identity) in Python
|
||
# and pass the derived key with KDF_NONE so the Rust FFI uses it directly.
|
||
# This matches Element Call's JS E2EE worker derivation exactly.
|
||
kp = self.lk_room.e2ee_manager.key_provider
|
||
|
||
# Bot's own key — pre-derive HKDF then set for outgoing encryption
|
||
bot_frame_key = _hkdf(self._bot_key, _RATCHET_SALT, bot_identity.encode())
|
||
kp.set_key(bot_identity, bot_frame_key, 0)
|
||
logger.info("Set bot frame key for %s (base=%d→derived=%d bytes)",
|
||
bot_identity, len(self._bot_key), len(bot_frame_key))
|
||
|
||
# Find the remote participant, wait up to 10s if not yet connected
|
||
remote_identity = None
|
||
for p in self.lk_room.remote_participants.values():
|
||
remote_identity = p.identity
|
||
break
|
||
if not remote_identity:
|
||
logger.info("No remote participant yet, waiting...")
|
||
for _ in range(100):
|
||
await asyncio.sleep(0.1)
|
||
for p in self.lk_room.remote_participants.values():
|
||
remote_identity = p.identity
|
||
break
|
||
if remote_identity:
|
||
break
|
||
|
||
# Set ALL known caller keys — pre-derive HKDF(base_key, ratchetSalt, identity).
|
||
# EC encrypts user audio with HKDF(user_base_key, "LKFrameEncryptionKey", user_identity).
|
||
# With KDF_NONE, the Rust FFI uses the key directly, so we must pre-derive.
|
||
if self._caller_all_keys and remote_identity:
|
||
try:
|
||
for idx, base_k in sorted(self._caller_all_keys.items()):
|
||
derived_k = _hkdf(base_k, _RATCHET_SALT, remote_identity.encode())
|
||
kp.set_key(remote_identity, derived_k, idx)
|
||
logger.info("Set caller frame key[%d] for %s (base=%d→derived=%d bytes)",
|
||
idx, remote_identity, len(base_k), len(derived_k))
|
||
# Belt+suspenders: also set via matrix identity if different from LK identity
|
||
if self._caller_identity and self._caller_identity != remote_identity:
|
||
for idx, base_k in sorted(self._caller_all_keys.items()):
|
||
derived_k = _hkdf(base_k, _RATCHET_SALT, self._caller_identity.encode())
|
||
kp.set_key(self._caller_identity, derived_k, idx)
|
||
logger.info("Also set caller keys via matrix identity %s", self._caller_identity)
|
||
except Exception as e:
|
||
logger.warning("Failed to set caller per-participant keys: %s", e)
|
||
elif not self._caller_all_keys:
|
||
logger.warning("No caller E2EE keys — incoming audio will be silence")
|
||
elif not remote_identity:
|
||
logger.warning("No remote participant found — caller keys not set")
|
||
|
||
if remote_identity:
|
||
logger.info("Linking to remote participant: %s", remote_identity)
|
||
|
||
# Voice pipeline — German male voice (Daniel)
|
||
self._http_session = aiohttp.ClientSession()
|
||
voice_id = os.environ.get("ELEVENLABS_VOICE_ID", DEFAULT_VOICE_ID)
|
||
self.session = AgentSession(
|
||
stt=elevenlabs.STT(api_key=ELEVENLABS_KEY, http_session=self._http_session),
|
||
llm=lk_openai.LLM(base_url=LITELLM_URL, api_key=LITELLM_KEY, model=self.model),
|
||
tts=elevenlabs.TTS(voice_id=voice_id, model="eleven_multilingual_v2",
|
||
api_key=ELEVENLABS_KEY, http_session=self._http_session),
|
||
vad=_get_vad(),
|
||
)
|
||
|
||
# Debug: log speech events
|
||
@self.session.on("user_speech_committed")
|
||
def _on_user_speech(msg):
|
||
logger.info("USER_SPEECH: %s", msg.text_content)
|
||
|
||
@self.session.on("agent_speech_committed")
|
||
def _on_agent_speech(msg):
|
||
logger.info("AGENT_SPEECH: %s", msg.text_content)
|
||
|
||
agent = Agent(instructions=VOICE_PROMPT)
|
||
io_opts = room_io.RoomOptions(
|
||
participant_identity=remote_identity,
|
||
close_on_disconnect=False,
|
||
) if remote_identity else room_io.RoomOptions(close_on_disconnect=False)
|
||
await self.session.start(
|
||
agent=agent,
|
||
room=self.lk_room,
|
||
room_options=io_opts,
|
||
)
|
||
logger.info("Voice pipeline started (voice=%s, linked_to=%s)", voice_id, remote_identity)
|
||
|
||
try:
|
||
await asyncio.wait_for(
|
||
self.session.generate_reply(
|
||
instructions="Sage nur: Hallo, wie kann ich helfen?"),
|
||
timeout=30.0)
|
||
logger.info("Greeting sent")
|
||
except asyncio.TimeoutError:
|
||
logger.error("Greeting timed out")
|
||
|
||
# Periodic E2EE state poll to diagnose incoming decryption
|
||
_poll_count = 0
|
||
while True:
|
||
await asyncio.sleep(5)
|
||
_poll_count += 1
|
||
if _poll_count % 6 == 0: # Every 30s
|
||
try:
|
||
cryptors = self.lk_room.e2ee_manager.frame_cryptors()
|
||
for c in cryptors:
|
||
logger.info("E2EE_CRYPTOR: identity=%s key_index=%d enabled=%s",
|
||
c.participant_identity, c.key_index, c.enabled)
|
||
except Exception as e:
|
||
logger.debug("frame_cryptors() failed: %s", e)
|
||
|
||
except asyncio.CancelledError:
|
||
logger.info("Session cancelled for %s", self.room_id)
|
||
except Exception:
|
||
logger.exception("Session failed for %s", self.room_id)
|