"""Cross-signing manager for Matrix bot accounts. Handles bootstrapping, verification, and recovery of cross-signing keys. Reusable by any bot (matrix-ai-agent, claude-matrix-bridge, etc.). """ import base64 import json import logging import os import subprocess import canonicaljson import httpx import olm.pk logger = logging.getLogger(__name__) class CrossSigningManager: """Manages cross-signing keys for a Matrix bot device.""" def __init__(self, homeserver: str, store_path: str, bot_password: str, vault_key: str = ""): """ Args: vault_key: Vault key for seed backup/recovery (e.g. "matrix.ai.cross_signing_seeds"). If empty, vault backup is disabled. """ self.homeserver = homeserver self.store_path = store_path self.bot_password = bot_password self.vault_key = vault_key self._xsign_file = os.path.join(store_path, "cross_signing_keys.json") async def ensure_cross_signing(self, user_id: str, device_id: str, access_token: str) -> bool: """Ensure device is cross-signed. Returns True if already signed or newly signed.""" headers = {"Authorization": f"Bearer {access_token}"} # Check if device already has a cross-signing signature if await self._is_device_cross_signed(user_id, device_id, headers): logger.info("Device %s already cross-signed, skipping bootstrap", device_id) return True # Load existing seeds or generate new ones master_seed, ss_seed, us_seed = self._load_or_generate_seeds() master_key = olm.pk.PkSigning(master_seed) self_signing_key = olm.pk.PkSigning(ss_seed) user_signing_key = olm.pk.PkSigning(us_seed) # Upload cross-signing keys and sign device success = await self._upload_and_sign( user_id, device_id, access_token, master_key, self_signing_key, user_signing_key, ) if success: self._save_seeds(master_seed, ss_seed, us_seed) # Verify the signature was applied if await self._is_device_cross_signed(user_id, device_id, headers): logger.info("Device %s cross-signed and verified successfully", device_id) return True else: logger.error("Device %s: signature upload succeeded but verification failed — retrying once", device_id) # Retry signing the device (keys already uploaded) retry = await self._sign_device(user_id, device_id, access_token, self_signing_key) if retry and await self._is_device_cross_signed(user_id, device_id, headers): logger.info("Device %s cross-signed on retry", device_id) return True logger.error("Device %s cross-signing failed after retry", device_id) return False async def verify_cross_signing_status(self, user_id: str, device_id: str, access_token: str) -> dict: """Check cross-signing status. Returns dict with status details.""" headers = {"Authorization": f"Bearer {access_token}"} is_signed = await self._is_device_cross_signed(user_id, device_id, headers) has_seeds = os.path.exists(self._xsign_file) return { "device_id": device_id, "cross_signed": is_signed, "seeds_stored": has_seeds, } def get_public_keys(self) -> dict | None: """Return public key fingerprints (safe to log). Returns None if no seeds.""" if not os.path.exists(self._xsign_file): return None master_seed, ss_seed, us_seed = self._load_or_generate_seeds() return { "master": olm.pk.PkSigning(master_seed).public_key, "self_signing": olm.pk.PkSigning(ss_seed).public_key, "user_signing": olm.pk.PkSigning(us_seed).public_key, } # -- Private methods -- async def _is_device_cross_signed(self, user_id: str, device_id: str, headers: dict) -> bool: """Query server to check if device has a cross-signing signature.""" try: async with httpx.AsyncClient(timeout=15.0) as hc: resp = await hc.post( f"{self.homeserver}/_matrix/client/v3/keys/query", json={"device_keys": {user_id: [device_id]}}, headers=headers, ) if resp.status_code != 200: return False device_keys = resp.json().get("device_keys", {}).get(user_id, {}).get(device_id, {}) sigs = device_keys.get("signatures", {}).get(user_id, {}) device_key_id = f"ed25519:{device_id}" return any(k != device_key_id for k in sigs) except Exception as e: logger.warning("Cross-signing check failed: %s", e) return False def _load_or_generate_seeds(self) -> tuple[bytes, bytes, bytes]: """Load seeds from file, vault backup, or generate new ones.""" # 1. Try local file first if os.path.exists(self._xsign_file): with open(self._xsign_file) as f: seeds = json.load(f) logger.info("Loaded cross-signing seeds from local store") return ( base64.b64decode(seeds["master_seed"]), base64.b64decode(seeds["self_signing_seed"]), base64.b64decode(seeds["user_signing_seed"]), ) # 2. Try vault recovery if self.vault_key: recovered = self._vault_recover_seeds() if recovered: logger.info("Recovered cross-signing seeds from vault") return recovered # 3. Generate new logger.info("Generating new cross-signing keys") return os.urandom(32), os.urandom(32), os.urandom(32) def _save_seeds(self, master: bytes, ss: bytes, us: bytes) -> None: """Persist seeds to local file and vault.""" if os.path.exists(self._xsign_file): # File exists; still try vault backup if not yet stored self._vault_backup_seeds(master, ss, us) return os.makedirs(os.path.dirname(self._xsign_file), exist_ok=True) with open(self._xsign_file, "w") as f: json.dump({ "master_seed": base64.b64encode(master).decode(), "self_signing_seed": base64.b64encode(ss).decode(), "user_signing_seed": base64.b64encode(us).decode(), }, f) logger.info("Cross-signing seeds saved to %s", self._xsign_file) self._vault_backup_seeds(master, ss, us) def _vault_backup_seeds(self, master: bytes, ss: bytes, us: bytes) -> None: """Backup seeds to vault for disaster recovery.""" if not self.vault_key: return payload = json.dumps({ "master_seed": base64.b64encode(master).decode(), "self_signing_seed": base64.b64encode(ss).decode(), "user_signing_seed": base64.b64encode(us).decode(), }) try: result = subprocess.run( ["vault", "set", self.vault_key], input=payload, capture_output=True, text=True, timeout=10, ) if result.returncode == 0: logger.info("Cross-signing seeds backed up to vault key %s", self.vault_key) else: logger.warning("Vault backup failed: %s", result.stderr.strip()) except (FileNotFoundError, subprocess.TimeoutExpired) as e: logger.warning("Vault backup skipped: %s", e) def _vault_recover_seeds(self) -> tuple[bytes, bytes, bytes] | None: """Attempt to recover seeds from vault.""" if not self.vault_key: return None try: result = subprocess.run( ["vault", "get", self.vault_key], capture_output=True, text=True, timeout=10, ) if result.returncode != 0 or not result.stdout.strip(): return None seeds = json.loads(result.stdout.strip()) return ( base64.b64decode(seeds["master_seed"]), base64.b64decode(seeds["self_signing_seed"]), base64.b64decode(seeds["user_signing_seed"]), ) except (FileNotFoundError, subprocess.TimeoutExpired, json.JSONDecodeError, KeyError) as e: logger.warning("Vault recovery failed: %s", e) return None async def _upload_and_sign( self, user_id: str, device_id: str, access_token: str, master_key: olm.pk.PkSigning, self_signing_key: olm.pk.PkSigning, user_signing_key: olm.pk.PkSigning, ) -> bool: """Upload cross-signing keys and sign device.""" headers = {"Authorization": f"Bearer {access_token}"} def make_key(usage: str, pubkey: str) -> dict: return { "user_id": user_id, "usage": [usage], "keys": {"ed25519:" + pubkey: pubkey}, } master_obj = make_key("master", master_key.public_key) ss_obj = make_key("self_signing", self_signing_key.public_key) us_obj = make_key("user_signing", user_signing_key.public_key) # Sign sub-keys with master key ss_canonical = canonicaljson.encode_canonical_json(ss_obj) ss_obj["signatures"] = {user_id: {"ed25519:" + master_key.public_key: master_key.sign(ss_canonical)}} us_canonical = canonicaljson.encode_canonical_json(us_obj) us_obj["signatures"] = {user_id: {"ed25519:" + master_key.public_key: master_key.sign(us_canonical)}} try: async with httpx.AsyncClient(timeout=15.0) as hc: # Upload cross-signing keys with password auth resp = await hc.post( f"{self.homeserver}/_matrix/client/v3/keys/device_signing/upload", json={ "master_key": master_obj, "self_signing_key": ss_obj, "user_signing_key": us_obj, "auth": { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": user_id}, "password": self.bot_password, }, }, headers=headers, ) if resp.status_code != 200: logger.error("Cross-signing upload failed: %d %s", resp.status_code, resp.text) return False logger.info("Cross-signing keys uploaded (master: %s)", master_key.public_key) # Sign the device return await self._sign_device(user_id, device_id, access_token, self_signing_key) except Exception as e: logger.error("Cross-signing bootstrap failed: %s", e, exc_info=True) return False async def _sign_device( self, user_id: str, device_id: str, access_token: str, self_signing_key: olm.pk.PkSigning, ) -> bool: """Sign a device with the self-signing key.""" headers = {"Authorization": f"Bearer {access_token}"} try: async with httpx.AsyncClient(timeout=15.0) as hc: qresp = await hc.post( f"{self.homeserver}/_matrix/client/v3/keys/query", json={"device_keys": {user_id: [device_id]}}, headers=headers, ) device_obj = qresp.json()["device_keys"][user_id][device_id] device_obj.pop("signatures", None) device_obj.pop("unsigned", None) dk_canonical = canonicaljson.encode_canonical_json(device_obj) dk_sig = self_signing_key.sign(dk_canonical) device_obj["signatures"] = { user_id: {"ed25519:" + self_signing_key.public_key: dk_sig}, } resp = await hc.post( f"{self.homeserver}/_matrix/client/v3/keys/signatures/upload", json={user_id: {device_id: device_obj}}, headers=headers, ) if resp.status_code != 200: logger.error("Device signature upload failed: %d %s", resp.status_code, resp.text) return False logger.info("Device %s signed with self-signing key", device_id) return True except Exception as e: logger.error("Device signing failed: %s", e, exc_info=True) return False