"""Unit tests for E2EE send-error handling and on_sync device-trust decisions.

Both test classes replicate the relevant logic of the patched bot module
locally instead of importing it, so the replicas must be kept in sync with
bot.py by hand.
"""

import asyncio
import logging
from unittest.mock import AsyncMock, Mock, patch

import pytest
from nio.exceptions import OlmUnverifiedDeviceError

from device_trust import CrossSignedOnlyPolicy

# Logger used by the replicated _send_text logic; tests read it via caplog.
logger = logging.getLogger("test_e2ee")

_ROOM_ID = "!room:ex.com"
_USER_ID = "@user:ex.com"


class TestSendTextErrorHandling:
    """Test _send_text resilience against E2EE errors."""

    # Canned output of the bot's markdown renderer.
    _HTML = "<p>test</p>"

    def _make_bot(self, room_send_side_effect=None):
        """Create a minimal bot-like Mock exposing what _send_text touches.

        Importing the real _send_text from the bot module would pull in too
        many dependencies, so the patched logic is replicated in
        ``_call_send_text`` instead.

        Args:
            room_send_side_effect: Forwarded as ``side_effect`` to the
                ``client.room_send`` AsyncMock, e.g. an exception instance to
                raise when the send is awaited.

        Returns:
            A Mock with ``client.room_send`` (AsyncMock) and ``_md_to_html``.
        """
        bot = Mock()
        bot.client = Mock()
        bot.client.room_send = AsyncMock(side_effect=room_send_side_effect)
        bot._md_to_html = Mock(return_value=self._HTML)
        return bot

    @pytest.mark.asyncio
    async def test_send_text_success(self):
        bot = self._make_bot()
        await self._call_send_text(bot, _ROOM_ID, "hello")
        # Pin the whole event payload, not merely that a send happened.
        bot.client.room_send.assert_awaited_once_with(
            _ROOM_ID,
            message_type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": "hello",
                "format": "org.matrix.custom.html",
                "formatted_body": self._HTML,
            },
        )
        bot._md_to_html.assert_called_once_with("hello")

    @pytest.mark.asyncio
    async def test_send_text_olm_unverified_does_not_crash(self, caplog):
        # matrix-nio constructs this error as (device, *message_args); passing
        # only a string would bind it to the device slot and lose the message.
        bot = self._make_bot(
            room_send_side_effect=OlmUnverifiedDeviceError(
                Mock(), "Device XYZABC not verified"
            )
        )
        with caplog.at_level(logging.ERROR):
            await self._call_send_text(bot, _ROOM_ID, "hello")
        assert "unverified device" in caplog.text
        assert "XYZABC" in caplog.text
        # The dedicated handler must win over the generic fallback.
        assert "Send failed" not in caplog.text

    @pytest.mark.asyncio
    async def test_send_text_generic_error_does_not_crash(self, caplog):
        bot = self._make_bot(room_send_side_effect=Exception("network timeout"))
        with caplog.at_level(logging.ERROR):
            await self._call_send_text(bot, _ROOM_ID, "hello")
        assert "Send failed" in caplog.text
        assert "network timeout" in caplog.text
        assert "unverified device" not in caplog.text

    async def _call_send_text(self, bot, room_id, text):
        """Replicate _send_text logic matching the patched bot.py.

        Sends ``text`` as an HTML-formatted ``m.room.message`` and logs, rather
        than propagates, send failures: OlmUnverifiedDeviceError gets its own
        message, any other exception falls through to a generic one.
        """
        try:
            await bot.client.room_send(
                room_id,
                message_type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": text,
                    "format": "org.matrix.custom.html",
                    "formatted_body": bot._md_to_html(text),
                },
            )
        except OlmUnverifiedDeviceError as e:
            logger.error("E2EE send failed in room %s: unverified device — %s", room_id, e)
        except Exception as e:
            logger.error("Send failed in room %s: %s", room_id, e)


class TestOnSyncDeviceVerification:
    """Test on_sync blacklists unverified devices instead of skipping."""

    def _make_device(self, device_id, cross_signed=False):
        """Build an unverified device Mock carrying a self-signature.

        When ``cross_signed`` is true the device also carries an
        ``ed25519:MASTER`` signature alongside its own.
        """
        device = Mock()
        device.device_id = device_id
        device.verified = False
        signatures = {f"ed25519:{device_id}": "self"}
        if cross_signed:
            signatures["ed25519:MASTER"] = "cross"
        device.signatures = {_USER_ID: signatures}
        return device

    @staticmethod
    def _run_on_sync(policy, client, devices, user_id=_USER_ID):
        """Replicate on_sync's per-device trust decision.

        Already-verified devices are left untouched; every other device is
        either verified (when ``policy`` trusts it) or blacklisted, never
        silently skipped.
        """
        for device in devices:
            if device.verified:
                continue
            if policy.should_trust(user_id, device):
                client.verify_device(device)
            else:
                client.blacklist_device(device)

    def test_blacklists_unverified_device(self):
        client = Mock()
        device = self._make_device("BADDEV", cross_signed=False)
        self._run_on_sync(CrossSignedOnlyPolicy(), client, [device])
        client.blacklist_device.assert_called_once_with(device)
        client.verify_device.assert_not_called()

    def test_verifies_cross_signed_device(self):
        client = Mock()
        device = self._make_device("GOODDEV", cross_signed=True)
        self._run_on_sync(CrossSignedOnlyPolicy(), client, [device])
        client.verify_device.assert_called_once_with(device)
        client.blacklist_device.assert_not_called()

    def test_mixed_devices(self):
        client = Mock()
        good = self._make_device("GOOD", cross_signed=True)
        bad = self._make_device("BAD", cross_signed=False)
        self._run_on_sync(CrossSignedOnlyPolicy(), client, [good, bad])
        client.verify_device.assert_called_once_with(good)
        client.blacklist_device.assert_called_once_with(bad)

    def test_skips_already_verified_device(self):
        client = Mock()
        policy = Mock()
        device = self._make_device("DONE", cross_signed=False)
        device.verified = True
        self._run_on_sync(policy, client, [device])
        policy.should_trust.assert_not_called()
        client.verify_device.assert_not_called()
        client.blacklist_device.assert_not_called()