Compare commits
246 Commits
Compare: 9bd7f27a84...session/CF
[Commit list (246 entries, ae2f34a3b6 … 382a98dd09): the compare view's author, message, and date columns were not captured; only bare SHA1s survived.]
.gitea/workflows/deploy.yml (new file, 61 lines)
@@ -0,0 +1,61 @@
name: Build & Deploy
on:
  push:
    branches: [main]
    paths-ignore: ['**.md', 'docs/**']
env:
  REGISTRY: gitea.agiliton.internal:3000
  IMAGE: gitea.agiliton.internal:3000/christian/matrix-ai-agent
  TARGET_VM: matrix.agiliton.internal
  DEPLOY_PATH: /opt/matrix-ai-agent
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: pip install -r requirements.txt -r requirements-test.txt
      - name: Run tests
        run: pytest tests/ -v --cov=device_trust --cov-report=term
  build-and-deploy:
    needs: [test]
    runs-on: ubuntu-latest
    steps:
      - name: Setup SSH
        run: |
          mkdir -p ~/.ssh && chmod 700 ~/.ssh
          echo "${{ secrets.SSH_PRIVATE_KEY }}" > ~/.ssh/id_ed25519
          chmod 600 ~/.ssh/id_ed25519
          ssh-keyscan -p 2222 gitea-ssh.agiliton.internal >> ~/.ssh/known_hosts 2>/dev/null || true
          ssh-keyscan -H ${{ env.TARGET_VM }} >> ~/.ssh/known_hosts 2>/dev/null || true
      - uses: actions/checkout@v4
        with:
          submodules: true
      - name: Login & Build & Push
        run: |
          echo "${{ secrets.REGISTRY_TOKEN }}" | docker login ${{ env.REGISTRY }} -u christian --password-stdin
          DOCKER_BUILDKIT=1 docker build --pull -t ${{ env.IMAGE }}:latest .
          docker push ${{ env.IMAGE }}:latest
      - name: Deploy
        run: |
          ssh root@${{ env.TARGET_VM }} << 'EOF'
          cd ${{ env.DEPLOY_PATH }} && git pull origin main --ff-only 2>/dev/null || true
          docker pull ${{ env.IMAGE }}:latest
          docker compose up -d --force-recreate --remove-orphans
          EOF
      - name: Smoke test
        run: |
          ssh root@${{ env.TARGET_VM }} << 'EOF'
          sleep 15
          docker exec matrix-ai-agent-bot-1 python3 -c "
          from bot import BOT_USER
          print(f'Bot user: {BOT_USER}')
          print('Smoke test passed')
          " || exit 1
          EOF
      - name: Cleanup
        if: always()
        run: docker builder prune -f --filter "until=24h" 2>/dev/null || true

.gitea/workflows/test.yml (new file, 21 lines)
@@ -0,0 +1,21 @@
name: Tests
on:
  push:
    branches: [main]
    paths-ignore: ['**.md', 'docs/**']
  pull_request:
    branches: [main]
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          pip install -r requirements.txt -r requirements-test.txt
      - name: Run tests
        run: |
          pytest tests/ -v --cov=device_trust --cov-report=term

.gitignore (vendored, 3 changes)
@@ -3,3 +3,6 @@ __pycache__/
*.pyc
.venv/
.claude-session/
*.bak
*.bak.*
memory-db-ssl/

.gitmodules (vendored, new file, 3 lines)
@@ -0,0 +1,3 @@
[submodule "confluence-collab"]
    path = confluence-collab
    url = ssh://git@gitea-ssh.agiliton.internal:2222/christian/confluence-collab.git

Dockerfile (10 changes)
@@ -1,6 +1,7 @@
# Stage 1: Build patched Rust FFI with HKDF support for Element Call E2EE
# Fork: onestacked/livekit-rust-sdks branch EC-compat-changes
# PR: https://github.com/livekit/rust-sdks/pull/904
# NOTE: PR #921 (native HKDF at C++ level) requires custom WebRTC build not yet available.
# Must use rust:latest (trixie/sid) — bookworm GCC 12 can't compile webrtc C++20 code
FROM rust:latest AS rust-build
RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -10,6 +11,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
WORKDIR /build
RUN git clone --branch EC-compat-changes --depth 1 --recurse-submodules \
    https://github.com/onestacked/livekit-rust-sdks.git
# Patch HKDF: limit output to 16 bytes (AES-128) matching Element Call JS SDK
# deriveKey({name:"AES-GCM", length:128}). C++ buffer may be larger but we
# only fill first 16 bytes to match the JS-derived key.
COPY hkdf_fix.py /tmp/hkdf_fix.py
RUN python3 /tmp/hkdf_fix.py /build/livekit-rust-sdks/livekit/src/room/e2ee/key_provider.rs
WORKDIR /build/livekit-rust-sdks/livekit-ffi
RUN cargo build --release

@@ -52,6 +58,10 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Install confluence-collab for section-based editing (CF-1812)
COPY confluence-collab/ /tmp/confluence-collab/
RUN pip install --no-cache-dir /tmp/confluence-collab/ && rm -rf /tmp/confluence-collab/

# Overwrite installed FFI binary with patched version (HKDF + key_ring_size support)
COPY --from=rust-build /build/livekit-rust-sdks/target/release/liblivekit_ffi.so /patched/
ENV LIVEKIT_LIB_PATH=/patched/liblivekit_ffi.so

[Second Dockerfile: the hunk below belongs to another Dockerfile whose file header was not captured in the compare view.]
@@ -3,4 +3,9 @@ WORKDIR /app
RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg libolm-dev && rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Install confluence-collab for section-based editing (CF-1812)
COPY confluence-collab/ /tmp/confluence-collab/
RUN pip install --no-cache-dir /tmp/confluence-collab/ && rm -rf /tmp/confluence-collab/

COPY . .

activity_video.py (new file, 161 lines)
@@ -0,0 +1,161 @@
"""Activity video track — pulsing orb (lightweight).

Small 160x120 canvas, only renders pixels near the orb.
LiveKit/browser upscales. Minimal CPU on both server and client.
"""

import asyncio
import math
import random
import logging
import time
import struct

from livekit.rtc import VideoSource, VideoFrame, VideoBufferType

logger = logging.getLogger("activity-video")

WIDTH = 160
HEIGHT = 120
FPS = 15
BPP = 4
CX, CY = WIDTH // 2, HEIGHT // 2

BG = (12, 12, 28)

STATE_COLORS = {
    "listening": (40, 120, 255),
    "thinking": (100, 60, 255),
    "speaking": (30, 200, 255),
    "initializing": (40, 60, 120),
}

_BG_PIXEL = struct.pack('BBBB', *BG, 255)
_BG_FRAME = _BG_PIXEL * (WIDTH * HEIGHT)

# Pre-compute distance from center — only within max possible glow radius
MAX_ORB = 45  # max orb radius at full energy
MAX_GLOW = int(MAX_ORB * 2.5) + 5
# Store sparse: list of (pixel_index, distance) for pixels within MAX_GLOW of center
_PIXELS = []
for _y in range(max(0, CY - MAX_GLOW), min(HEIGHT, CY + MAX_GLOW + 1)):
    dy = _y - CY
    for _x in range(max(0, CX - MAX_GLOW), min(WIDTH, CX + MAX_GLOW + 1)):
        dx = _x - CX
        d = math.sqrt(dx * dx + dy * dy)
        if d <= MAX_GLOW:
            _PIXELS.append((_y * WIDTH + _x, d))


class ActivityVideoPublisher:
    def __init__(self):
        self.source = VideoSource(WIDTH, HEIGHT)
        self._state = "initializing"
        self._stopped = False
        self._pulse = 0.0
        self._energy = 0.0
        self._target_energy = 0.0
        self._color = list(STATE_COLORS["initializing"])
        self._target_color = list(STATE_COLORS["initializing"])
        self._ring_phase = 0.0

    def set_state(self, state: str):
        if self._state != state:
            logger.info("Activity video state: %s -> %s", self._state, state)
            self._state = state
            self._target_color = list(STATE_COLORS.get(state, STATE_COLORS["initializing"]))

    def stop(self):
        self._stopped = True

    def _update(self, t: float):
        state = self._state
        for i in range(3):
            self._color[i] += (self._target_color[i] - self._color[i]) * 0.08

        if state == "listening":
            self._target_energy = 0.3
            self._pulse = 0.5 * math.sin(t * 1.5) + 0.5
        elif state == "thinking":
            self._target_energy = 0.6
            self._pulse = 0.5 * math.sin(t * 3.0) + 0.5
        elif state == "speaking":
            self._target_energy = 0.9 + random.uniform(-0.1, 0.1)
            self._pulse = 0.5 * math.sin(t * 6.0) + 0.5 + random.uniform(-0.15, 0.15)
        else:
            self._target_energy = 0.15
            self._pulse = 0.3

        self._energy += (self._target_energy - self._energy) * 0.12
        self._ring_phase = t

    def _render_frame(self) -> bytearray:
        buf = bytearray(_BG_FRAME)

        r, g, b = self._color
        energy = self._energy
        pulse = self._pulse
        bg_r, bg_g, bg_b = BG

        base_radius = 15 + 8 * energy
        orb_radius = base_radius + 4 * pulse * energy
        glow_radius = orb_radius * 2.5
        inv_orb = 1.0 / max(orb_radius, 1)
        glow_span = glow_radius - orb_radius
        inv_glow = 1.0 / max(glow_span, 1)

        ring_active = self._state == "speaking"
        if ring_active:
            ring1_r = orb_radius + ((self._ring_phase * 30) % glow_span)
            ring2_r = orb_radius + ((self._ring_phase * 30 + glow_span * 0.5) % glow_span)

        for idx, dist in _PIXELS:
            if dist > glow_radius:
                continue

            if dist <= orb_radius:
                f = dist * inv_orb
                brightness = 1.0 - 0.3 * f * f
                white = max(0.0, 1.0 - f * 2.5) * 0.6 * energy
                pr = min(255, int(r * brightness + 255 * white))
                pg = min(255, int(g * brightness + 255 * white))
                pb = min(255, int(b * brightness + 255 * white))
            else:
                f = (dist - orb_radius) * inv_glow
                t3 = 1.0 - f
                glow = t3 * t3 * t3 * energy * 0.5

                if ring_active:
                    for rr in (ring1_r, ring2_r):
                        rd = abs(dist - rr)
                        if rd < 4:
                            glow += (1.0 - rd * 0.25) * 0.3 * (1.0 - f)

                pr = min(255, int(bg_r + r * glow))
                pg = min(255, int(bg_g + g * glow))
                pb = min(255, int(bg_b + b * glow))

            off = idx * BPP
            buf[off] = pr
            buf[off + 1] = pg
            buf[off + 2] = pb

        return buf

    async def run(self):
        logger.info("Activity video loop started (%dx%d @ %d FPS, orb mode, %d active pixels)",
                    WIDTH, HEIGHT, FPS, len(_PIXELS))
        interval = 1.0 / FPS
        t0 = time.monotonic()
        rgba_type = VideoBufferType.Value('RGBA')

        while not self._stopped:
            t = time.monotonic() - t0
            self._update(t)
            buf = self._render_frame()
            frame = VideoFrame(WIDTH, HEIGHT, rgba_type, buf)
            self.source.capture_frame(frame)
            render_time = time.monotonic() - t0 - t
            await asyncio.sleep(max(0.001, interval - render_time))

        logger.info("Activity video loop stopped")
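
Reviewer note: a minimal sketch of how this publisher is driven, using only the API defined in activity_video.py above. Running it outside a LiveKit room is illustrative only, since the rendered frames go nowhere without a published track.

# Illustrative sketch, not part of this diff: drive the orb through its states.
import asyncio
from activity_video import ActivityVideoPublisher

async def demo():
    pub = ActivityVideoPublisher()
    loop_task = asyncio.create_task(pub.run())  # renders into pub.source at 15 FPS
    for state in ("listening", "thinking", "speaking"):
        pub.set_state(state)  # color and energy ease toward the new state
        await asyncio.sleep(2)
    pub.stop()  # run() exits after the current frame
    await loop_task

asyncio.run(demo())
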
agent.py (26 changes)
@@ -1,17 +1,27 @@
import asyncio
import os
import json
import base64
import logging

import sentry_sdk

from livekit.agents import Agent, AgentSession, AgentServer, JobContext, JobProcess, cli
from livekit.plugins import openai as lk_openai, elevenlabs, silero
import livekit.rtc as rtc

from e2ee_patch import KDF_HKDF
from activity_video import ActivityVideoPublisher

logger = logging.getLogger("matrix-ai-agent")
logging.basicConfig(level=logging.DEBUG)

# Sentry error tracking
_sentry_dsn = os.environ.get("SENTRY_DSN", "")
if _sentry_dsn:
    sentry_sdk.init(dsn=_sentry_dsn, traces_sample_rate=0.1, environment=os.environ.get("SENTRY_ENV", "production"))
    logger.info("Sentry initialized for agent")

LITELLM_URL = os.environ["LITELLM_BASE_URL"]
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")

@@ -20,7 +30,10 @@ Rules:
 - Keep answers SHORT — 1-3 sentences max
 - Be direct, no filler words
 - If the user wants more detail, they will ask
-- Speak naturally as in a conversation"""
+- Speak naturally as in a conversation
+- Always focus on the user's most recent message. Do not continue or summarize previous conversations
+- If a voice message contains only noise, silence, or filler sounds, ignore it completely
+- When a user greets you or starts a new conversation, greet briefly and wait for instructions"""

 server = AgentServer()

@@ -92,6 +105,13 @@ async def entrypoint(ctx: JobContext):
    logger.info("Connected to room, local identity: %s", ctx.room.local_participant.identity)
    logger.info("Remote participants: %s", list(ctx.room.remote_participants.keys()))

    # Publish activity video track (pulsing orb)
    activity_video = ActivityVideoPublisher()
    video_track = rtc.LocalVideoTrack.create_video_track("activity", activity_video.source)
    await ctx.room.local_participant.publish_track(video_track)
    activity_task = asyncio.create_task(activity_video.run())
    logger.info("Activity video track published")

    model = os.environ.get("LITELLM_MODEL", "claude-sonnet")
    voice_id = os.environ.get("ELEVENLABS_VOICE_ID", "21m00Tcm4TlvDq8ikWAM")

@@ -109,6 +129,10 @@ async def entrypoint(ctx: JobContext):
        vad=ctx.proc.userdata["vad"],
    )

    @session.on("agent_state_changed")
    def on_state_changed(ev):
        activity_video.set_state(ev.new_state)

    @session.on("user_speech_committed")
    def on_speech(msg):
        logger.info("USER_SPEECH_COMMITTED: %s", msg.text_content)

article_summary/__init__.py (new file, 526 lines)
@@ -0,0 +1,526 @@
"""Blinkist-style article audio summary handler for Matrix bot."""

from __future__ import annotations

import logging
import re
from typing import TYPE_CHECKING

from openai import AsyncOpenAI

from .state import ArticleState, SessionManager
from .extractor import extract_article, detect_topics, is_article_url
from .summarizer import summarize_article
from .tts import generate_audio

if TYPE_CHECKING:
    pass  # Bot type would cause circular import

logger = logging.getLogger("article-summary")

# URL regex — matches http/https URLs in message text
URL_PATTERN = re.compile(r'https?://[^\s\)>\]"]+')

CANCEL_WORDS = {"cancel", "stop", "abbrechen", "abbruch", "nevermind"}

# Keyword sets for robust option matching (substring search, not exact match)
_DISCUSS_KW = {"discuss", "diskutieren", "besprechen", "reden", "talk", "chat"}
_TEXT_KW = {"text", "zusammenfassung", "summary", "lesen", "read", "schriftlich", "written"}
_AUDIO_KW = {"audio", "mp3", "anhören", "vorlesen", "hören", "listen", "blinkist", "abspielen", "podcast"}

# Simple German detection: common words that appear frequently in German text
_DE_INDICATORS = {"der", "die", "das", "und", "ist", "ein", "eine", "für", "mit", "auf", "den", "dem", "sich", "nicht", "von", "wird", "auch", "nach", "wie", "aber"}

LANGUAGE_OPTIONS = {
    "1": ("en", "English"),
    "2": ("de", "German"),
    "en": ("en", "English"),
    "de": ("de", "German"),
    "english": ("en", "English"),
    "german": ("de", "German"),
    "deutsch": ("de", "German"),
}

DURATION_OPTIONS = {
    "1": 5,
    "2": 10,
    "3": 15,
    "5": 5,
    "10": 10,
    "15": 15,
}


def _detect_content_lang(text: str) -> str:
    """Detect language from text content. Returns 'de' or 'en'."""
    words = set(re.findall(r'\b\w+\b', text.lower()))
    de_hits = len(words & _DE_INDICATORS)
    return "de" if de_hits >= 4 else "en"


def _classify_choice(body: str) -> str | None:
    """Classify user's action choice from free-form text.

    Returns 'discuss', 'text', 'audio', or None (unrecognized).
    """
    # Normalize: lowercase, strip punctuation around digits
    raw = body.strip().lower()
    # Extract bare number if message is just "3." or "3!" or "nummer 3" etc.
    num_match = re.search(r'\b([123])\b', raw)
    bare_num = num_match.group(1) if num_match else None

    # Number-only messages (highest priority — unambiguous)
    stripped = re.sub(r'[^\w\s]', '', raw).strip()
    if stripped in ("1", "2", "3"):
        return {"1": "discuss", "2": "text", "3": "audio"}[stripped]

    # Keyword search (substring matching)
    if any(kw in raw for kw in _AUDIO_KW):
        return "audio"
    if any(kw in raw for kw in _TEXT_KW):
        return "text"
    if any(kw in raw for kw in _DISCUSS_KW):
        return "discuss"

    # "nummer 3" / "option 3" / "3. bitte" — number in context
    if bare_num:
        return {"1": "discuss", "2": "text", "3": "audio"}[bare_num]

    return None


class ArticleSummaryHandler:
    """Handles the interactive article summary conversation flow."""

    def __init__(
        self,
        llm_client: AsyncOpenAI,
        model: str,
        elevenlabs_key: str,
        voice_id: str,
        firecrawl_url: str | None = None,
    ) -> None:
        self.llm = llm_client
        self.model = model
        self.elevenlabs_key = elevenlabs_key
        self.voice_id = voice_id
        self.firecrawl_url = firecrawl_url
        self.sessions = SessionManager()

    async def handle_message(
        self, room_id: str, sender: str, body: str
    ) -> str | None:
        """Process a message through the article summary FSM.

        Returns:
            - None: Not handled (pass to normal AI handler).
            - str: Text response to send.
            - "__GENERATE__": Signal to run the full generation pipeline.
        """
        body_lower = body.strip().lower()
        session = self.sessions.get(sender, room_id)

        # Cancel from any active state
        if session.state != ArticleState.IDLE and body_lower in CANCEL_WORDS:
            ui_de = session.ui_language == "de"
            self.sessions.reset(sender, room_id)
            return "Zusammenfassung abgebrochen." if ui_de else "Summary cancelled."

        # Route based on current state
        if session.state == ArticleState.IDLE:
            return await self._check_for_url(room_id, sender, body)

        elif session.state == ArticleState.URL_DETECTED:
            # Waiting for user to pick action (discuss, text summary, audio)
            return await self._on_action_choice(room_id, sender, body, body_lower)

        elif session.state == ArticleState.AWAITING_LANGUAGE:
            # Audio flow: waiting for language selection
            return self._on_language(room_id, sender, body_lower)

        elif session.state == ArticleState.LANGUAGE:
            # Waiting for duration selection
            return self._on_duration(room_id, sender, body_lower)

        elif session.state == ArticleState.DURATION:
            # Waiting for topic selection
            return self._on_topics(room_id, sender, body)

        elif session.state == ArticleState.GENERATING:
            if session.ui_language == "de":
                return "Zusammenfassung wird noch erstellt, bitte warten..."
            return "Still generating your summary, please wait..."

        elif session.state == ArticleState.COMPLETE:
            # Follow-up Q&A about the article
            return await self._on_followup(room_id, sender, body)

        return None

    async def _check_for_url(
        self, room_id: str, sender: str, body: str
    ) -> str | None:
        """Check if message contains an article URL."""
        urls = URL_PATTERN.findall(body)
        # Filter to article-like URLs
        article_urls = [u for u in urls if is_article_url(u)]
        if not article_urls:
            return None

        url = article_urls[0]
        session = self.sessions.get(sender, room_id)

        # Extract article content
        logger.info("Extracting article from %s", url)
        article = await extract_article(url, self.firecrawl_url)
        if not article:
            return None  # Could not extract — let normal handler deal with it

        session.url = url
        session.title = article["title"]
        session.content = article["content"]
        word_count = article["word_count"]
        read_time = max(1, word_count // 200)

        # Detect topics via LLM
        session.detected_topics = await detect_topics(
            article["content"], self.llm, self.model
        )

        session.state = ArticleState.URL_DETECTED
        self.sessions.touch(sender, room_id)

        topics_hint = ""
        if session.detected_topics:
            topics_hint = f"\nTopics: {', '.join(session.detected_topics)}"

        # Detect article language for localized UI
        lang = _detect_content_lang(session.content[:2000])
        session.ui_language = lang

        if lang == "de":
            return (
                f"**Gefunden:** {session.title} (~{read_time} min Lesezeit){topics_hint}\n\n"
                f"Was möchtest du damit machen?\n"
                f"1\ufe0f\u20e3 **Diskutieren** \u2014 Ich lese den Artikel und wir reden darüber\n"
                f"2\ufe0f\u20e3 **Textzusammenfassung** \u2014 Kurze schriftliche Zusammenfassung\n"
                f"3\ufe0f\u20e3 **Audiozusammenfassung** \u2014 Blinkist-Style MP3\n\n"
                f"_(oder schreib einfach weiter \u2014 ich unterbreche nicht)_"
            )
        return (
            f"**Found:** {session.title} (~{read_time} min read){topics_hint}\n\n"
            f"What would you like to do?\n"
            f"1\ufe0f\u20e3 **Discuss** \u2014 I'll read the article and we can talk about it\n"
            f"2\ufe0f\u20e3 **Text summary** \u2014 Quick written summary\n"
            f"3\ufe0f\u20e3 **Audio summary** \u2014 Blinkist-style MP3\n\n"
            f"_(or just keep chatting \u2014 I won't interrupt)_"
        )

    def _on_language(
        self, room_id: str, sender: str, choice: str
    ) -> str | None:
        """Handle language selection."""
        lang = LANGUAGE_OPTIONS.get(choice)
        session = self.sessions.get(sender, room_id)
        ui_de = session.ui_language == "de"
        if not lang:
            if ui_de:
                return "Bitte wähle eine Sprache: **1** für Englisch, **2** für Deutsch."
            return "Please pick a language: **1** for English, **2** for German."

        session.language = lang[0]
        session.state = ArticleState.LANGUAGE
        self.sessions.touch(sender, room_id)

        if ui_de:
            return (
                f"Sprache: **{lang[1]}**. Wie lang soll die Zusammenfassung sein?\n"
                f"1️⃣ 5 Min (kurz)\n"
                f"2️⃣ 10 Min (standard)\n"
                f"3️⃣ 15 Min (ausführlich)"
            )
        return (
            f"Language: **{lang[1]}**. How long should the summary be?\n"
            f"1️⃣ 5 min (short)\n"
            f"2️⃣ 10 min (standard)\n"
            f"3️⃣ 15 min (detailed)"
        )

    def _on_duration(
        self, room_id: str, sender: str, choice: str
    ) -> str | None:
        """Handle duration selection."""
        duration = DURATION_OPTIONS.get(choice)
        session = self.sessions.get(sender, room_id)
        ui_de = session.ui_language == "de"
        if not duration:
            if ui_de:
                return "Bitte wähle: **1** (5 Min), **2** (10 Min) oder **3** (15 Min)."
            return "Please pick: **1** (5 min), **2** (10 min), or **3** (15 min)."

        session.duration_minutes = duration
        session.state = ArticleState.DURATION
        self.sessions.touch(sender, room_id)

        if session.detected_topics:
            topic_list = "\n".join(
                f"  • {t}" for t in session.detected_topics
            )
            if ui_de:
                return (
                    f"Dauer: **{duration} Min**. Auf welche Themen fokussieren?\n"
                    f"{topic_list}\n\n"
                    f"Antworte mit Themennummern (kommagetrennt), bestimmten Themen oder **alle**."
                )
            return (
                f"Duration: **{duration} min**. Focus on which topics?\n"
                f"{topic_list}\n\n"
                f"Reply with topic numbers (comma-separated), specific topics, or **all**."
            )
        else:
            if ui_de:
                return (
                    f"Dauer: **{duration} Min**. Bestimmte Themen im Fokus?\n"
                    f"Antworte mit Themen (kommagetrennt) oder **alle** für eine allgemeine Zusammenfassung."
                )
            return (
                f"Duration: **{duration} min**. Any specific topics to focus on?\n"
                f"Reply with topics (comma-separated) or **all** for a general summary."
            )

    def _on_topics(
        self, room_id: str, sender: str, body: str
    ) -> str | None:
        """Handle topic selection. Returns __GENERATE__ to trigger pipeline."""
        session = self.sessions.get(sender, room_id)
        body_lower = body.strip().lower()

        if body_lower in ("all", "alle", "everything", "alles"):
            session.topics = session.detected_topics or []
        else:
            # Try to match by number
            parts = re.split(r'[,\s]+', body.strip())
            selected = []
            for p in parts:
                p = p.strip()
                if p.isdigit():
                    idx = int(p) - 1
                    if 0 <= idx < len(session.detected_topics):
                        selected.append(session.detected_topics[idx])
                elif p:
                    selected.append(p)
            session.topics = selected or session.detected_topics or []

        session.state = ArticleState.GENERATING
        self.sessions.touch(sender, room_id)
        return "__GENERATE__"

    async def _on_action_choice(
        self, room_id: str, sender: str, body: str, body_lower: str
    ) -> str | None:
        """Handle user's choice after URL detection: discuss, text summary, or audio."""
        session = self.sessions.get(sender, room_id)
        choice = _classify_choice(body)

        if choice == "discuss":
            article_context = session.content[:8000]
            title = session.title
            self.sessions.reset(sender, room_id)
            return f"__DISCUSS__{title}\n{article_context}"

        if choice == "text":
            return await self._generate_text_summary(room_id, sender)

        if choice == "audio":
            return self._prompt_language(room_id, sender)

        # Unrecognized — user is just chatting, pass through with article context
        article_context = session.content[:8000]
        title = session.title
        self.sessions.reset(sender, room_id)
        return f"__DISCUSS__{title}\n{article_context}"

    def _prompt_language(self, room_id: str, sender: str) -> str:
        """Present language selection for audio summary."""
        session = self.sessions.get(sender, room_id)
        session.state = ArticleState.AWAITING_LANGUAGE
        self.sessions.touch(sender, room_id)
        if session.ui_language == "de":
            return (
                "In welcher Sprache soll die Audiozusammenfassung sein?\n"
                "1\ufe0f\u20e3 Englisch\n"
                "2\ufe0f\u20e3 Deutsch"
            )
        return (
            "What language for the audio summary?\n"
            "1\ufe0f\u20e3 English\n"
            "2\ufe0f\u20e3 German"
        )

    async def _generate_text_summary(self, room_id: str, sender: str) -> str | None:
        """Generate a text-only summary of the article."""
        session = self.sessions.get(sender, room_id)
        try:
            resp = await self.llm.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "Summarize this article concisely in 3-5 paragraphs. "
                            "Respond in the same language as the article."
                        ),
                    },
                    {
                        "role": "user",
                        "content": f"Article: {session.title}\n\n{session.content[:12000]}",
                    },
                ],
                max_tokens=1000,
                temperature=0.3,
            )
            summary = resp.choices[0].message.content.strip()
            session.summary_text = summary
            session.state = ArticleState.COMPLETE
            self.sessions.touch(sender, room_id)
            if session.ui_language == "de":
                return (
                    f"**Zusammenfassung: {session.title}**\n\n{summary}\n\n"
                    f"_Stelle Folgefragen oder teile einen neuen Link._"
                )
            return (
                f"**Summary: {session.title}**\n\n{summary}\n\n"
                f"_Ask follow-up questions or share a new link._"
            )
        except Exception:
            logger.warning("Text summary failed", exc_info=True)
            self.sessions.reset(sender, room_id)
            return None

    async def generate_and_post(self, bot, room_id: str, sender: str) -> None:
        """Run the full pipeline: summarize → TTS → upload MP3."""
        session = self.sessions.get(sender, room_id)

        ui_de = session.ui_language == "de"
        topics_str = ", ".join(session.topics) if session.topics else ("alle Themen" if ui_de else "all topics")
        if ui_de:
            await bot._send_text(
                room_id,
                f"Erstelle {session.duration_minutes}-Min {session.language.upper()} "
                f"Zusammenfassung von **{session.title}** (Fokus: {topics_str})...",
            )
        else:
            await bot._send_text(
                room_id,
                f"Generating {session.duration_minutes}-min {session.language.upper()} "
                f"summary of **{session.title}** (focus: {topics_str})...",
            )

        try:
            # Step 1: Summarize
            summary = await summarize_article(
                content=session.content,
                language=session.language,
                duration_minutes=session.duration_minutes,
                topics=session.topics,
                llm_client=self.llm,
                model=self.model,
            )
            session.summary_text = summary

            # Step 2: TTS
            mp3_bytes, duration_secs = await generate_audio(
                text=summary,
                api_key=self.elevenlabs_key,
                voice_id=self.voice_id,
                language=session.language,
            )

            # Step 3: Upload and send audio
            filename = re.sub(r'[^\w\s-]', '', session.title)[:50].strip()
            filename = f"{filename}.mp3" if filename else "summary.mp3"

            await bot._send_audio(room_id, mp3_bytes, filename, duration_secs)

            # Step 4: Send transcript
            transcript_preview = summary[:500]
            if len(summary) > 500:
                transcript_preview += "..."
            if ui_de:
                await bot._send_text(
                    room_id,
                    f"**Zusammenfassung von:** {session.title}\n\n{transcript_preview}\n\n"
                    f"_Du kannst Folgefragen zu diesem Artikel stellen._",
                )
            else:
                await bot._send_text(
                    room_id,
                    f"**Summary of:** {session.title}\n\n{transcript_preview}\n\n"
                    f"_You can ask follow-up questions about this article._",
                )

            session.state = ArticleState.COMPLETE
            self.sessions.touch(sender, room_id)

        except Exception:
            logger.exception("Article summary pipeline failed for %s", session.url)
            if ui_de:
                await bot._send_text(
                    room_id, "Entschuldigung, die Audiozusammenfassung konnte nicht erstellt werden. Bitte versuche es erneut."
                )
            else:
                await bot._send_text(
                    room_id, "Sorry, I couldn't generate the audio summary. Please try again."
                )
            self.sessions.reset(sender, room_id)

    async def _on_followup(
        self, room_id: str, sender: str, body: str
    ) -> str | None:
        """Answer follow-up questions about the summarized article."""
        session = self.sessions.get(sender, room_id)

        # If user posts a new URL, start fresh
        urls = URL_PATTERN.findall(body)
        if any(is_article_url(u) for u in urls):
            self.sessions.reset(sender, room_id)
            return await self._check_for_url(room_id, sender, body)

        # Check if it looks like a question about the article
        question_indicators = ["?", "what", "how", "why", "explain", "was", "wie", "warum", "erkläre"]
        is_question = any(q in body.lower() for q in question_indicators)
        if not is_question:
            # Not a question — reset and let normal handler take over
            self.sessions.reset(sender, room_id)
            return None

        try:
            resp = await self.llm.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "You are answering follow-up questions about an article. "
                            "Use the article content below to answer. Be concise. "
                            "Respond in the same language as the question."
                        ),
                    },
                    {
                        "role": "user",
                        "content": (
                            f"Article: {session.title}\n\n"
                            f"{session.content[:8000]}\n\n"
                            f"Summary: {session.summary_text[:3000]}\n\n"
                            f"Question: {body}"
                        ),
                    },
                ],
                max_tokens=500,
                temperature=0.5,
            )
            return resp.choices[0].message.content.strip()
        except Exception:
            logger.warning("Follow-up Q&A failed", exc_info=True)
            self.sessions.reset(sender, room_id)
            return None

article_summary/extractor.py (new file, 146 lines)
@@ -0,0 +1,146 @@
"""Article content extraction via Firecrawl with BeautifulSoup fallback."""

from __future__ import annotations

import logging
import re

import httpx
from bs4 import BeautifulSoup

logger = logging.getLogger("article-summary.extractor")

MAX_CONTENT_CHARS = 15_000

# Domains that are not articles (social media, file hosts, etc.)
NON_ARTICLE_DOMAINS = {
    "youtube.com", "youtu.be", "twitter.com", "x.com", "instagram.com",
    "facebook.com", "tiktok.com", "reddit.com", "discord.com",
    "drive.google.com", "docs.google.com", "github.com",
}


def is_article_url(url: str) -> bool:
    """Check if URL is likely an article (not social media, files, etc.)."""
    try:
        from urllib.parse import urlparse
        host = urlparse(url).hostname or ""
        host = host.removeprefix("www.")
        return host not in NON_ARTICLE_DOMAINS
    except Exception:
        return False


async def extract_article(url: str, firecrawl_url: str | None = None) -> dict | None:
    """Extract article content from URL.

    Returns dict with: title, content, word_count.
    Returns None if extraction fails.
    """
    title = ""
    content = ""

    # Try Firecrawl first
    if firecrawl_url:
        try:
            result = await _firecrawl_extract(url, firecrawl_url)
            if result:
                title, content = result
        except Exception:
            logger.warning("Firecrawl extraction failed for %s", url, exc_info=True)

    # Fallback to BeautifulSoup
    if not content:
        try:
            result = await _bs4_extract(url)
            if result:
                title, content = result
        except Exception:
            logger.warning("BS4 extraction failed for %s", url, exc_info=True)

    if not content:
        return None

    content = content[:MAX_CONTENT_CHARS]
    word_count = len(content.split())

    return {
        "title": title or url,
        "content": content,
        "word_count": word_count,
    }


async def _firecrawl_extract(url: str, firecrawl_url: str) -> tuple[str, str] | None:
    """Extract via Firecrawl API."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            f"{firecrawl_url}/v1/scrape",
            json={"url": url, "formats": ["markdown"]},
        )
        resp.raise_for_status()
        data = resp.json()

    doc = data.get("data", {})
    title = doc.get("metadata", {}).get("title", "")
    content = doc.get("markdown", "")
    if not content:
        return None
    return title, content


async def _bs4_extract(url: str) -> tuple[str, str] | None:
    """Fallback extraction via httpx + BeautifulSoup."""
    headers = {
        "User-Agent": "Mozilla/5.0 (compatible; ArticleSummaryBot/1.0)",
        "Accept": "text/html",
    }
    async with httpx.AsyncClient(timeout=20.0, follow_redirects=True) as client:
        resp = await client.get(url, headers=headers)
        resp.raise_for_status()

    soup = BeautifulSoup(resp.text, "html.parser")

    # Extract title
    title = ""
    if soup.title:
        title = soup.title.get_text(strip=True)

    # Remove script/style/nav elements
    for tag in soup(["script", "style", "nav", "header", "footer", "aside", "form"]):
        tag.decompose()

    # Try <article> tag first, then <main>, then body
    article = soup.find("article") or soup.find("main") or soup.find("body")
    if not article:
        return None

    # Get text, clean up whitespace
    text = article.get_text(separator="\n", strip=True)
    text = re.sub(r"\n{3,}", "\n\n", text)

    if len(text) < 100:
        return None

    return title, text


async def detect_topics(content: str, llm_client, model: str) -> list[str]:
    """Use LLM to detect 3-5 key topics from article content."""
    snippet = content[:2000]
    try:
        resp = await llm_client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "Extract 3-5 key topics from this article. Return ONLY a comma-separated list of short topic labels (2-4 words each). No numbering, no explanation."},
                {"role": "user", "content": snippet},
            ],
            max_tokens=100,
            temperature=0.3,
        )
        raw = resp.choices[0].message.content.strip()
        topics = [t.strip() for t in raw.split(",") if t.strip()]
        return topics[:5]
    except Exception:
        logger.warning("Topic detection failed", exc_info=True)
        return []

article_summary/state.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""Per-user FSM state machine for article summary conversations."""

from __future__ import annotations

import time
from dataclasses import dataclass, field
from enum import Enum, auto


class ArticleState(Enum):
    IDLE = auto()
    URL_DETECTED = auto()
    AWAITING_LANGUAGE = auto()  # Audio flow: waiting for language selection
    LANGUAGE = auto()
    DURATION = auto()
    TOPICS = auto()
    GENERATING = auto()
    COMPLETE = auto()


@dataclass
class ArticleSession:
    state: ArticleState = ArticleState.IDLE
    url: str = ""
    title: str = ""
    content: str = ""
    language: str = ""  # Audio output language (set by user choice)
    ui_language: str = "en"  # Detected from article content, used for UI strings
    duration_minutes: int = 10
    topics: list[str] = field(default_factory=list)
    detected_topics: list[str] = field(default_factory=list)
    summary_text: str = ""
    timestamp: float = field(default_factory=time.time)


STATE_TIMEOUT = 300  # 5 minutes


class SessionManager:
    """Manage per-(user, room) article summary sessions."""

    def __init__(self) -> None:
        self._sessions: dict[tuple[str, str], ArticleSession] = {}

    def get(self, user_id: str, room_id: str) -> ArticleSession:
        key = (user_id, room_id)
        session = self._sessions.get(key)
        if session and time.time() - session.timestamp > STATE_TIMEOUT:
            session = None
            self._sessions.pop(key, None)
        if session is None:
            session = ArticleSession()
            self._sessions[key] = session
        return session

    def reset(self, user_id: str, room_id: str) -> None:
        self._sessions.pop((user_id, room_id), None)

    def touch(self, user_id: str, room_id: str) -> None:
        key = (user_id, room_id)
        if key in self._sessions:
            self._sessions[key].timestamp = time.time()
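
Reviewer note: a short usage sketch of the session lifecycle above (the Matrix IDs are hypothetical). The key behavior is that get() transparently expires idle sessions.

# Illustrative sketch, not part of this diff.
from article_summary.state import ArticleState, SessionManager

mgr = SessionManager()
s = mgr.get("@alice:example.org", "!room:example.org")  # fresh IDLE session
s.state = ArticleState.URL_DETECTED
mgr.touch("@alice:example.org", "!room:example.org")    # refresh the 300 s timeout
# If no touch() happens for STATE_TIMEOUT seconds, the next get() for this
# (user, room) pair drops the stale session and returns a new IDLE one.
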
article_summary/summarizer.py (new file, 68 lines)
@@ -0,0 +1,68 @@
"""LLM-powered article summarization with personalization."""

from __future__ import annotations

import logging

from openai import AsyncOpenAI

logger = logging.getLogger("article-summary.summarizer")

WORDS_PER_MINUTE = 150  # Clear narration pace


async def summarize_article(
    content: str,
    language: str,
    duration_minutes: int,
    topics: list[str],
    llm_client: AsyncOpenAI,
    model: str,
) -> str:
    """Generate a narrative summary of article content.

    Args:
        content: Article text (max ~15K chars).
        language: Target language ("en" or "de").
        duration_minutes: Target audio duration (5, 10, or 15).
        topics: Focus topics selected by user.
        llm_client: AsyncOpenAI instance (LiteLLM).
        model: Model name to use.

    Returns:
        Summary text ready for TTS.
    """
    word_target = duration_minutes * WORDS_PER_MINUTE
    lang_name = "German" if language == "de" else "English"
    topics_str = ", ".join(topics) if topics else "all topics"

    system_prompt = f"""You are a professional audio narrator creating a Blinkist-style summary.

RULES:
- Write in {lang_name}.
- Target approximately {word_target} words (for a {duration_minutes}-minute audio).
- Focus on: {topics_str}.
- Use a conversational, engaging narrator tone — as if explaining to a curious friend.
- Structure: brief intro → key insights → practical takeaways → brief conclusion.
- Use flowing prose, NOT bullet points or lists.
- Do NOT include any formatting markers, headers, or markdown.
- Do NOT say "In this article..." — jump straight into the content.
- Make it sound natural when read aloud."""

    # Truncate very long content
    if len(content) > 12_000:
        content = content[:12_000] + "\n\n[Article continues...]"

    resp = await llm_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Summarize this article:\n\n{content}"},
        ],
        max_tokens=word_target * 2,  # tokens ≈ 1.3x words, with headroom
        temperature=0.7,
    )

    summary = resp.choices[0].message.content.strip()
    logger.info("Generated summary: %d words (target: %d)", len(summary.split()), word_target)
    return summary

article_summary/tts.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""ElevenLabs TTS — direct API calls to generate MP3 audio."""

from __future__ import annotations

import io
import logging

import httpx

logger = logging.getLogger("article-summary.tts")

ELEVENLABS_API = "https://api.elevenlabs.io/v1"
CHUNK_SIZE = 5000  # Max chars per TTS request


async def generate_audio(
    text: str,
    api_key: str,
    voice_id: str,
    language: str = "en",
) -> tuple[bytes, float]:
    """Generate MP3 audio from text via ElevenLabs API.

    Args:
        text: Text to convert to speech.
        api_key: ElevenLabs API key.
        voice_id: ElevenLabs voice ID.
        language: Language code ("en" or "de").

    Returns:
        Tuple of (mp3_bytes, estimated_duration_seconds).
    """
    chunks = _split_text(text, CHUNK_SIZE)
    mp3_parts: list[bytes] = []

    for i, chunk in enumerate(chunks):
        logger.info("Generating TTS chunk %d/%d (%d chars)", i + 1, len(chunks), len(chunk))
        mp3_data = await _tts_request(chunk, api_key, voice_id, language)
        mp3_parts.append(mp3_data)

    combined = b"".join(mp3_parts)

    # Estimate duration: ~150 words per minute
    word_count = len(text.split())
    est_duration = (word_count / 150) * 60

    logger.info("TTS complete: %d bytes, ~%.0fs estimated", len(combined), est_duration)
    return combined, est_duration


async def _tts_request(
    text: str,
    api_key: str,
    voice_id: str,
    language: str,
) -> bytes:
    """Single TTS API call."""
    url = f"{ELEVENLABS_API}/text-to-speech/{voice_id}"
    headers = {
        "xi-api-key": api_key,
        "Content-Type": "application/json",
        "Accept": "audio/mpeg",
    }
    payload = {
        "text": text,
        "model_id": "eleven_multilingual_v2",
        "voice_settings": {
            "stability": 0.5,
            "similarity_boost": 0.75,
        },
    }
    # Add language hint for non-English
    if language == "de":
        payload["language_code"] = "de"

    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(url, json=payload, headers=headers)
        resp.raise_for_status()
        return resp.content


def _split_text(text: str, max_chars: int) -> list[str]:
    """Split text at sentence boundaries for TTS chunking."""
    if len(text) <= max_chars:
        return [text]

    chunks: list[str] = []
    current = ""

    for sentence in _sentence_split(text):
        if len(current) + len(sentence) > max_chars and current:
            chunks.append(current.strip())
            current = sentence
        else:
            current += sentence

    if current.strip():
        chunks.append(current.strip())

    return chunks or [text[:max_chars]]


def _sentence_split(text: str) -> list[str]:
    """Split text into sentences, keeping delimiters attached."""
    import re
    parts = re.split(r'(?<=[.!?])\s+', text)
    # Re-add trailing space for joining
    return [p + " " for p in parts]
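
Reviewer note: the chunker flushes at sentence boundaries, so no single request exceeds CHUNK_SIZE. A small worked example, traced against the logic above with a deliberately tiny max_chars:

# Illustrative sketch, not part of this diff.
from article_summary.tts import _split_text

print(_split_text("First sentence. Second sentence! Third?", max_chars=20))
# -> ['First sentence.', 'Second sentence!', 'Third?']
# Each chunk is stripped when flushed; real calls use max_chars=5000.
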
confluence-collab (submodule, 1 change)
Submodule confluence-collab added at c4238974a7

cron/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .scheduler import CronScheduler

__all__ = ["CronScheduler"]

145
cron/brave_search.py
Normal file
145
cron/brave_search.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Brave Search executor for cron jobs with optional LLM filtering."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
from .formatter import format_search_results
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BRAVE_API_KEY = os.environ.get("BRAVE_API_KEY", "")
|
||||
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
|
||||
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "")
|
||||
FILTER_MODEL = os.environ.get("BASE_MODEL", "claude-haiku")
|
||||
|
||||
FILTER_SYSTEM_PROMPT = """You are a strict search result filter. Given search results and filtering criteria, evaluate each result's title and description against ALL criteria. A result must clearly match EVERY criterion to be included. When in doubt, EXCLUDE.
|
||||
|
||||
Rules:
|
||||
- If a result is a general article, category page, or ad rather than a specific listing, EXCLUDE it.
|
||||
- If the location/region cannot be confirmed from the title or description, EXCLUDE it.
|
||||
- If any single criterion is not met or unclear, EXCLUDE the result.
|
||||
|
||||
Return ONLY a JSON array of matching indices (0-based). If none match, return [].
No explanation, just the array."""


async def _llm_filter(results: list[dict], criteria: str) -> list[dict]:
    """Use LLM to filter search results against user-defined criteria."""
    if not LITELLM_URL or not LITELLM_KEY:
        logger.warning("LLM not configured, skipping filter")
        return results

    # Build a concise representation of results for the LLM
    result_descriptions = []
    for i, r in enumerate(results):
        title = r.get("title", "")
        desc = r.get("description", "")
        url = r.get("url", "")
        result_descriptions.append(f"[{i}] {title} — {desc} ({url})")

    user_msg = (
        f"**Criteria:** {criteria}\n\n"
        f"**Results:**\n" + "\n".join(result_descriptions)
    )

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{LITELLM_URL}/chat/completions",
                headers={"Authorization": f"Bearer {LITELLM_KEY}"},
                json={
                    "model": FILTER_MODEL,
                    "messages": [
                        {"role": "system", "content": FILTER_SYSTEM_PROMPT},
                        {"role": "user", "content": user_msg},
                    ],
                    "temperature": 0,
                    "max_tokens": 200,
                },
            )
            resp.raise_for_status()
            data = resp.json()

        reply = data["choices"][0]["message"]["content"].strip()
        # Extract JSON array from response (LLM may include extra text)
        import re
        match = re.search(r"\[[\d,\s]*\]", reply)
        if not match:
            logger.warning("LLM filter returned no array: %s", reply)
            return results
        indices = json.loads(match.group())
        if not isinstance(indices, list):
            logger.warning("LLM filter returned non-list: %s", reply)
            return results

        filtered = [results[i] for i in indices if 0 <= i < len(results)]
        logger.info(
            "LLM filter: %d/%d results matched criteria",
            len(filtered), len(results),
        )
        return filtered

    except Exception as exc:
        logger.warning("LLM filter failed, returning all results: %s", exc)
        return results


async def execute_brave_search(job: dict, send_text, **_kwargs) -> dict:
    """Run a Brave Search query, dedup, optionally LLM-filter, post to Matrix."""
    if not BRAVE_API_KEY:
        return {"status": "error", "error": "BRAVE_API_KEY not configured"}

    config = job.get("config", {})
    query = config.get("query", "")
    criteria = config.get("criteria", "")
    max_results = config.get("maxResults", 10)
    target_room = job["targetRoom"]
    dedup_keys = set(job.get("dedupKeys", []))

    if not query:
        return {"status": "error", "error": "No search query configured"}

    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get(
                "https://api.search.brave.com/res/v1/web/search",
                headers={
                    "Accept": "application/json",
                    "X-Subscription-Token": BRAVE_API_KEY,
                },
                params={"q": query, "count": max_results, "text_decorations": False},
            )
            resp.raise_for_status()
            data = resp.json()

        results = data.get("web", {}).get("results", [])
        if not results:
            return {"status": "no_results"}

        # Dedup by URL
        new_results = [r for r in results if r.get("url") not in dedup_keys]

        if not new_results:
            return {"status": "no_results"}

        # LLM filter if criteria provided
        if criteria:
            new_results = await _llm_filter(new_results, criteria)
            if not new_results:
                return {"status": "no_results"}

        msg = format_search_results(job["name"], new_results)
        await send_text(target_room, msg)

        new_keys = [r["url"] for r in new_results if r.get("url")]
        return {
            "status": "success",
            "newDedupKeys": new_keys,
        }

    except Exception as exc:
        logger.error("Brave search cron failed: %s", exc, exc_info=True)
        return {"status": "error", "error": str(exc)}
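A quick sketch of the job payload this executor consumes. Field names are taken from the code above; the room ID, query, and send_text callable are made up for illustration:

import asyncio

async def demo():
    job = {
        "name": "HN mentions",
        "targetRoom": "!room:example.org",
        "dedupKeys": [],  # URLs already posted by earlier runs
        "config": {
            "query": "site:news.ycombinator.com matrix bots",
            "criteria": "Only results about self-hosted Matrix bots",
            "maxResults": 10,
        },
    }

    async def send_text(room_id: str, text: str) -> None:
        print(f"[{room_id}] {text}")  # stand-in for the bot's Matrix sender

    print(await execute_brave_search(job, send_text))
    # -> {"status": "success", "newDedupKeys": [...]} on a hit,
    #    {"status": "no_results"} once every URL is already in dedupKeys

asyncio.run(demo())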
192 cron/browser_executor.py Normal file
@@ -0,0 +1,192 @@
"""Browser scrape executor — dispatches jobs to Skyvern API."""

import asyncio
import json
import logging
import os

import httpx

logger = logging.getLogger(__name__)

SKYVERN_BASE_URL = os.environ.get("SKYVERN_BASE_URL", "http://skyvern:8000")
SKYVERN_API_KEY = os.environ.get("SKYVERN_API_KEY", "")

POLL_INTERVAL = 5  # seconds
MAX_POLL_TIME = 300  # 5 minutes


async def _create_task(url: str, goal: str, extraction_goal: str = "",
                       extraction_schema: dict | None = None,
                       credential_id: str | None = None, totp_identifier: str | None = None) -> str:
    """Create a Skyvern task and return the task_id."""
    payload: dict = {
        "url": url,
        "navigation_goal": goal,
        "data_extraction_goal": extraction_goal or goal,
    }
    if extraction_schema:
        payload["extracted_information_schema"] = extraction_schema
    if credential_id:
        payload["credential_id"] = credential_id
    if totp_identifier:
        payload["totp_identifier"] = totp_identifier

    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{SKYVERN_BASE_URL}/api/v1/tasks",
            headers={
                "Content-Type": "application/json",
                "x-api-key": SKYVERN_API_KEY,
            },
            json=payload,
        )
        resp.raise_for_status()
        data = resp.json()
        return data["task_id"]


async def _poll_task(run_id: str) -> dict:
    """Poll Skyvern until task completes or times out."""
    elapsed = 0
    async with httpx.AsyncClient(timeout=60.0) as client:
        while elapsed < MAX_POLL_TIME:
            resp = await client.get(
                f"{SKYVERN_BASE_URL}/api/v1/tasks/{run_id}",
                headers={"x-api-key": SKYVERN_API_KEY},
            )
            resp.raise_for_status()
            data = resp.json()
            status = data.get("status", "")

            if status in ("completed", "failed", "terminated", "timed_out"):
                return data

            await asyncio.sleep(POLL_INTERVAL)
            elapsed += POLL_INTERVAL

    return {"status": "timed_out", "error": f"Polling exceeded {MAX_POLL_TIME}s"}


def _format_extraction(data: dict) -> str:
    """Format extracted data as readable markdown for Matrix."""
    extracted = data.get("extracted_information") or data.get("extracted_data")
    if not extracted:
        return "No data extracted."

    # Handle list of items (most common: news, listings, results)
    items = None
    if isinstance(extracted, list):
        items = extracted
    elif isinstance(extracted, dict):
        # Look for the first list value in the dict (e.g. {"news": [...]})
        for v in extracted.values():
            if isinstance(v, list) and v:
                items = v
                break

    if items and isinstance(items[0], dict):
        lines = []
        for item in items:
            # Try common field names for title/link
            title = item.get("title") or item.get("name") or item.get("headline") or ""
            link = item.get("link") or item.get("url") or item.get("href") or ""
            # Build a line with remaining fields as details
            skip = {"title", "name", "headline", "link", "url", "href"}
            details = " · ".join(
                str(v) for k, v in item.items()
                if k not in skip and v
            )
            if title and link:
                line = f"- [{title}]({link})"
            elif title:
                line = f"- {title}"
            else:
                line = f"- {json.dumps(item, ensure_ascii=False)}"
            if details:
                line += f" \n {details}"
            lines.append(line)
        return "\n".join(lines)

    # Fallback: compact JSON
    if isinstance(extracted, (dict, list)):
        return json.dumps(extracted, indent=2, ensure_ascii=False)
    return str(extracted)


async def execute_browser_scrape(job: dict, send_text, **_kwargs) -> dict:
    """Execute a browser-based scraping job via Skyvern."""
    target_room = job["targetRoom"]
    config = job.get("config", {})
    url = config.get("url", "")
    goal = config.get("goal", config.get("query", f"Scrape content from {url}"))
    extraction_goal = config.get("extractionGoal", "") or goal
    extraction_schema = config.get("extractionSchema")
    browser_profile = job.get("browserProfile")

    if not url:
        await send_text(target_room, f"**{job['name']}**: No URL configured.")
        return {"status": "error", "error": "No URL configured"}

    if not SKYVERN_API_KEY:
        await send_text(
            target_room,
            f"**{job['name']}**: Browser automation not configured (missing API key).",
        )
        return {"status": "error", "error": "SKYVERN_API_KEY not set"}

    # Map browser profile fields to Skyvern credential
    credential_id = None
    totp_identifier = None
    if browser_profile:
        if browser_profile.get("status") == "expired":
            await send_text(
                target_room,
                f"**{job['name']}**: Browser credential expired. "
                f"Update at https://matrixhost.eu/settings/automations",
            )
            return {"status": "error", "error": "Browser credential expired"}
        credential_id = browser_profile.get("credentialId")
        totp_identifier = browser_profile.get("totpIdentifier")

    try:
        run_id = await _create_task(
            url=url,
            goal=goal,
            extraction_goal=extraction_goal,
            extraction_schema=extraction_schema,
            credential_id=credential_id,
            totp_identifier=totp_identifier,
        )
        logger.info("Skyvern task created: %s for job %s", run_id, job["name"])

        result = await _poll_task(run_id)
        status = result.get("status", "unknown")

        if status == "completed":
            extracted = _format_extraction(result)
            msg = f"**{job['name']}** — {url}\n\n{extracted}"
            # Truncate if too long for Matrix
            if len(msg) > 4000:
                msg = msg[:3950] + "\n\n_(truncated)_"
            await send_text(target_room, msg)
            return {"status": "success"}
        else:
            error = result.get("error") or result.get("failure_reason") or status
            await send_text(
                target_room,
                f"**{job['name']}**: Browser task {status} — {error}",
            )
            return {"status": "error", "error": str(error)}

    except httpx.HTTPStatusError as exc:
        error_msg = f"Skyvern API error: {exc.response.status_code}"
        logger.error("Browser executor failed: %s", error_msg, exc_info=True)
        await send_text(target_room, f"**{job['name']}**: {error_msg}")
        return {"status": "error", "error": error_msg}

    except Exception as exc:
        error_msg = str(exc)
        logger.error("Browser executor failed: %s", error_msg, exc_info=True)
        await send_text(target_room, f"**{job['name']}**: Browser task failed — {error_msg}")
        return {"status": "error", "error": error_msg}
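For reference, a hypothetical browser_scrape job config. The field names mirror the config.get() calls above; the extractionSchema shape follows Skyvern's JSON-Schema-style extracted_information_schema, and the URL and goals are invented:

job = {
    "name": "Flat hunt",
    "targetRoom": "!room:example.org",
    "browserProfile": None,  # or {"credentialId": "...", "totpIdentifier": "..."}
    "config": {
        "url": "https://example.org/listings",
        "goal": "Open the listings page and collect all visible offers",
        "extractionGoal": "Extract title, price, location and url of each listing",
        "extractionSchema": {
            "type": "object",
            "properties": {
                "listings": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "title": {"type": "string"},
                            "price": {"type": "string"},
                            "location": {"type": "string"},
                            "url": {"type": "string"},
                        },
                    },
                },
            },
        },
    },
}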
28 cron/executor.py Normal file
@@ -0,0 +1,28 @@
"""Dispatch cron jobs to the correct executor by job_type."""

import logging

from .brave_search import execute_brave_search
from .browser_executor import execute_browser_scrape
from .reminder import execute_reminder

logger = logging.getLogger(__name__)

EXECUTORS = {
    "brave_search": execute_brave_search,
    "browser_scrape": execute_browser_scrape,
    "reminder": execute_reminder,
}


async def execute_job(job: dict, send_text, matrix_client) -> dict:
    """Execute a cron job and return a result dict for reporting."""
    job_type = job["jobType"]
    executor = EXECUTORS.get(job_type)

    if not executor:
        msg = f"Unknown job type: {job_type}"
        logger.error(msg)
        return {"status": "error", "error": msg}

    return await executor(job=job, send_text=send_text, matrix_client=matrix_client)
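Every executor shares the (job, send_text, **_kwargs) coroutine signature, so wiring up a new job type is a single dict entry. A hypothetical example:

async def execute_ping(job: dict, send_text, **_kwargs) -> dict:
    await send_text(job["targetRoom"], f"pong from {job['name']}")
    return {"status": "success"}

EXECUTORS["ping"] = execute_ping  # a portal-side jobType "ping" would now dispatch here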
56 cron/formatter.py Normal file
@@ -0,0 +1,56 @@
"""Format cron job results as Matrix messages (markdown)."""


def format_search_results(job_name: str, results: list[dict]) -> str:
    """Format Brave Search results as a markdown message for Matrix."""
    count = len(results)
    lines = [f"**{job_name}** \u2014 {count} new result{'s' if count != 1 else ''}:\n"]

    for i, r in enumerate(results, 1):
        title = r.get("title", "Untitled")
        url = r.get("url", "")
        desc = r.get("description", "")
        lines.append(f"{i}. **[{title}]({url})**")
        if desc:
            lines.append(f" {desc}")
        lines.append("")

    lines.append(
        "_[Manage automations](https://matrixhost.eu/settings/automations)_"
    )
    return "\n".join(lines)


def format_listings(job_name: str, listings: list[dict]) -> str:
    """Format browser-scraped listings as a markdown message for Matrix."""
    count = len(listings)
    lines = [f"**{job_name}** \u2014 {count} new listing{'s' if count != 1 else ''}:\n"]

    for i, item in enumerate(listings, 1):
        title = item.get("title", "Unknown")
        price = item.get("price", "")
        location = item.get("location", "")
        url = item.get("url", "")
        age = item.get("age", "")

        line = f"{i}. **{title}**"
        if price:
            line += f" \u2014 {price}"
        lines.append(line)

        details = []
        if location:
            details.append(f"\U0001f4cd {location}")
        if age:
            details.append(f"\U0001f4c5 {age}")
        if url:
            details.append(f"[View listing]({url})")
        if details:
            sep = " \u00b7 "
            lines.append(f" {sep.join(details)}")
        lines.append("")

    lines.append(
        "_[Manage automations](https://matrixhost.eu/settings/automations)_"
    )
    return "\n".join(lines)
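An illustrative call with one fabricated result, to show the message shape the bot posts:

print(format_search_results("HN mentions", [
    {"title": "Matrix bot tips", "url": "https://example.org/a", "description": "Discussion thread"},
]))
# **HN mentions** — 1 new result:
#
# 1. **[Matrix bot tips](https://example.org/a)**
#  Discussion thread
#
# _[Manage automations](https://matrixhost.eu/settings/automations)_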
20 cron/reminder.py Normal file
@@ -0,0 +1,20 @@
"""Simple reminder executor for cron jobs."""

import logging

logger = logging.getLogger(__name__)


async def execute_reminder(job: dict, send_text, **_kwargs) -> dict:
    """Post a reminder message to a Matrix room."""
    config = job.get("config", {})
    message = config.get("message", "")
    target_room = job["targetRoom"]

    if not message:
        return {"status": "error", "error": "No reminder message configured"}

    text = f"\u23f0 **{job['name']}:** {message}"
    await send_text(target_room, text)

    return {"status": "success"}
328 cron/scheduler.py Normal file
@@ -0,0 +1,328 @@
"""Cron job scheduler that syncs with matrixhost-web API and executes jobs."""

import asyncio
import logging
from datetime import datetime, timedelta

import httpx

from .executor import execute_job
from pipelines import PipelineEngine, PipelineStateManager

logger = logging.getLogger(__name__)

SYNC_INTERVAL = 300  # 5 minutes — full job reconciliation
PENDING_CHECK_INTERVAL = 15  # 15 seconds — fast check for manual triggers
MAX_CONCURRENT_PER_USER = 3  # CF-2411: prevent runaway pipelines


class CronScheduler:
    """Fetches enabled cron jobs from the matrixhost portal and runs them on schedule."""

    def __init__(self, portal_url: str, api_key: str, matrix_client, send_text_fn,
                 llm_client=None, default_model: str = "claude-haiku",
                 escalation_model: str = "claude-sonnet"):
        self.portal_url = portal_url.rstrip("/")
        self.api_key = api_key
        self.matrix_client = matrix_client
        self.send_text = send_text_fn
        self._jobs: dict[str, dict] = {}  # id -> job data
        self._tasks: dict[str, asyncio.Task] = {}  # id -> scheduler task
        self._running = False

        # Pipeline engine
        self._pipeline_state = PipelineStateManager(portal_url, api_key)
        self.pipeline_engine = PipelineEngine(
            state=self._pipeline_state,
            send_text=send_text_fn,
            matrix_client=matrix_client,
            llm_client=llm_client,
            default_model=default_model,
            escalation_model=escalation_model,
        )
        self._pipelines: dict[str, dict] = {}  # id -> pipeline data
        self._pipeline_tasks: dict[str, asyncio.Task] = {}  # id -> scheduler task
        self._running_jobs: set[str] = set()  # job IDs currently executing

    async def start(self):
        """Start the scheduler background loops."""
        self._running = True
        logger.info("Cron scheduler starting")
        await asyncio.sleep(15)  # wait for bot to stabilize
        # Run full sync + fast pending check in parallel
        await asyncio.gather(
            self._full_sync_loop(),
            self._pending_check_loop(),
            self._pipeline_sync_loop(),
            self._pipeline_pending_check_loop(),
        )

    async def _full_sync_loop(self):
        """Full job reconciliation every 5 minutes."""
        while self._running:
            try:
                await self._sync_jobs()
            except Exception:
                logger.warning("Cron job sync failed", exc_info=True)
            await asyncio.sleep(SYNC_INTERVAL)

    async def _pending_check_loop(self):
        """Fast poll for manual triggers every 15 seconds."""
        while self._running:
            try:
                await self._check_pending()
            except Exception:
                logger.debug("Pending check failed", exc_info=True)
            await asyncio.sleep(PENDING_CHECK_INTERVAL)

    async def _check_pending(self):
        """Quick check for jobs with lastStatus='pending' and run them."""
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                f"{self.portal_url}/api/cron/jobs/active",
                headers={"x-api-key": self.api_key},
            )
            if resp.status_code != 200:
                return
            data = resp.json()

        for job in data.get("jobs", []):
            if job.get("lastStatus") == "pending" and job["id"] not in self._running_jobs:
                logger.info("Pending trigger: %s", job["name"])
                self._running_jobs.add(job["id"])
                asyncio.create_task(self._run_once(job))

    async def _pipeline_sync_loop(self):
        """Full pipeline reconciliation every 5 minutes."""
        while self._running:
            try:
                await self._sync_pipelines()
            except Exception:
                logger.warning("Pipeline sync failed", exc_info=True)
            await asyncio.sleep(SYNC_INTERVAL)

    async def _pipeline_pending_check_loop(self):
        """Fast poll for manually triggered pipelines every 15 seconds."""
        while self._running:
            try:
                await self._check_pending_pipelines()
            except Exception:
                logger.debug("Pipeline pending check failed", exc_info=True)
            await asyncio.sleep(PENDING_CHECK_INTERVAL)

    async def _sync_pipelines(self):
        """Fetch active pipelines from portal and reconcile."""
        pipelines = await self._pipeline_state.fetch_active_pipelines()
        remote = {p["id"]: p for p in pipelines}

        # Remove pipelines no longer active
        for pid in list(self._pipeline_tasks):
            if pid not in remote:
                logger.info("Removing pipeline %s (no longer active)", pid)
                self._pipeline_tasks[pid].cancel()
                del self._pipeline_tasks[pid]
                self._pipelines.pop(pid, None)

        # Add/update cron-triggered pipelines
        for pid, pipeline in remote.items():
            existing = self._pipelines.get(pid)
            if existing and existing.get("updatedAt") == pipeline.get("updatedAt"):
                continue

            if pid in self._pipeline_tasks:
                self._pipeline_tasks[pid].cancel()

            self._pipelines[pid] = pipeline

            if pipeline.get("triggerType") == "cron":
                self._pipeline_tasks[pid] = asyncio.create_task(
                    self._pipeline_cron_loop(pipeline), name=f"pipeline-{pid}"
                )
                logger.info("Scheduled pipeline: %s (%s @ %s)",
                            pipeline["name"], pipeline.get("schedule", ""), pipeline.get("scheduleAt", ""))

    async def _check_pending_pipelines(self):
        """Check for pipelines with lastStatus='pending' and run them."""
        pipelines = await self._pipeline_state.fetch_active_pipelines()
        for pipeline in pipelines:
            if pipeline.get("lastStatus") == "pending":
                # CF-2411: concurrent limit check
                user_id = pipeline.get("userId", "")
                if user_id:
                    active = await self._pipeline_state.count_active_executions(user_id)
                    if active >= MAX_CONCURRENT_PER_USER:
                        logger.warning(
                            "Pipeline %s skipped: user %s has %d active executions (limit %d)",
                            pipeline["name"], user_id, active, MAX_CONCURRENT_PER_USER,
                        )
                        continue
                logger.info("Pending pipeline trigger: %s", pipeline["name"])
                asyncio.create_task(self.pipeline_engine.run(pipeline))

    async def _pipeline_cron_loop(self, pipeline: dict):
        """Run a pipeline on its cron schedule."""
        try:
            while True:
                sleep_secs = self._seconds_until_next_run(pipeline)
                if sleep_secs > 0:
                    await asyncio.sleep(sleep_secs)
                # CF-2411: concurrent limit check
                user_id = pipeline.get("userId", "")
                if user_id:
                    active = await self._pipeline_state.count_active_executions(user_id)
                    if active >= MAX_CONCURRENT_PER_USER:
                        logger.warning(
                            "Pipeline %s cron skipped: user %s at limit (%d/%d)",
                            pipeline["name"], user_id, active, MAX_CONCURRENT_PER_USER,
                        )
                        continue
                await self.pipeline_engine.run(pipeline)
        except asyncio.CancelledError:
            pass

    def get_file_upload_pipelines(self) -> list[dict]:
        """Return all active file_upload-triggered pipelines."""
        return [p for p in self._pipelines.values() if p.get("triggerType") == "file_upload"]

    async def stop(self):
        self._running = False
        for task in self._tasks.values():
            task.cancel()
        self._tasks.clear()
        for task in self._pipeline_tasks.values():
            task.cancel()
        self._pipeline_tasks.clear()

    async def _sync_jobs(self):
        """Fetch active jobs from portal and reconcile with running tasks."""
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get(
                f"{self.portal_url}/api/cron/jobs/active",
                headers={"x-api-key": self.api_key},
            )
            resp.raise_for_status()
            data = resp.json()

        remote_jobs = {j["id"]: j for j in data.get("jobs", [])}

        # Remove jobs that are no longer active
        for job_id in list(self._tasks):
            if job_id not in remote_jobs:
                logger.info("Removing cron job %s (no longer active)", job_id)
                self._tasks[job_id].cancel()
                del self._tasks[job_id]
                self._jobs.pop(job_id, None)

        # Add/update jobs
        for job_id, job in remote_jobs.items():
            existing = self._jobs.get(job_id)
            if existing and existing.get("updatedAt") == job.get("updatedAt"):
                continue  # unchanged

            # Cancel old task if updating
            if job_id in self._tasks:
                self._tasks[job_id].cancel()

            self._jobs[job_id] = job
            self._tasks[job_id] = asyncio.create_task(
                self._job_loop(job), name=f"cron-{job_id}"
            )
            logger.info("Scheduled cron job: %s (%s @ %s %s)",
                        job["name"], job["schedule"], job.get("scheduleAt", ""), job["timezone"])

    async def _job_loop(self, job: dict):
        """Run a job on its schedule forever."""
        try:
            while True:
                sleep_secs = self._seconds_until_next_run(job)
                if sleep_secs > 0:
                    await asyncio.sleep(sleep_secs)
                await self._run_once(job)
        except asyncio.CancelledError:
            pass

    async def _run_once(self, job: dict):
        """Execute a single job run and report results back."""
        job_id = job["id"]
        logger.info("Running cron job: %s (%s)", job["name"], job["jobType"])
        try:
            result = await execute_job(
                job=job,
                send_text=self.send_text,
                matrix_client=self.matrix_client,
            )
            await self._report_result(job_id, result)
        except Exception as exc:
            logger.error("Cron job %s failed: %s", job["name"], exc, exc_info=True)
            await self._report_result(job_id, {
                "status": "error",
                "error": str(exc),
            })
        finally:
            self._running_jobs.discard(job_id)

    async def _report_result(self, job_id: str, result: dict):
        """Report job execution result back to the portal."""
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                await client.post(
                    f"{self.portal_url}/api/cron/jobs/{job_id}/result",
                    headers={"x-api-key": self.api_key},
                    json=result,
                )
        except Exception:
            logger.warning("Failed to report cron result for %s", job_id, exc_info=True)

    def _seconds_until_next_run(self, job: dict) -> float:
        """Calculate seconds until the next scheduled run."""
        import zoneinfo

        schedule = job["schedule"]
        schedule_at = job.get("scheduleAt", "09:00") or "09:00"
        tz = zoneinfo.ZoneInfo(job.get("timezone", "Europe/Berlin"))
        now = datetime.now(tz)

        hour, minute = (int(x) for x in schedule_at.split(":"))

        if schedule == "hourly":
            # Run at the top of every hour; adding a timedelta handles the
            # day rollover at 23:xx, where replace(hour=now.hour + 1) would raise
            next_run = now.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
            return (next_run - now).total_seconds()

        if schedule == "daily":
            next_run = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
            if next_run <= now:
                next_run += timedelta(days=1)
            return (next_run - now).total_seconds()

        if schedule == "weekly":
            # Monday = 0
            days_ahead = (0 - now.weekday()) % 7 or 7
            next_run = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
            if now.weekday() == 0 and next_run > now:
                days_ahead = 0
            next_run += timedelta(days=days_ahead)
            if next_run <= now:
                next_run += timedelta(days=7)
            return (next_run - now).total_seconds()

        if schedule == "weekdays":
            next_run = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
            if next_run <= now:
                next_run += timedelta(days=1)
            # Skip weekends
            while next_run.weekday() >= 5:
                next_run += timedelta(days=1)
            return (next_run - now).total_seconds()

        # Default: daily
        next_run = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
        if next_run <= now:
            next_run += timedelta(days=1)
        return (next_run - now).total_seconds()
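A minimal wiring sketch, assuming PORTAL_URL and BOT_API_KEY come from the compose environment further down and matrix_client/send_text come from the bot:

import os

async def run_scheduler(matrix_client, send_text):
    scheduler = CronScheduler(
        portal_url=os.environ["PORTAL_URL"],
        api_key=os.environ["BOT_API_KEY"],
        matrix_client=matrix_client,
        send_text_fn=send_text,
    )
    try:
        await scheduler.start()  # blocks: gathers the four sync/poll loops
    finally:
        await scheduler.stop()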
292 cross_signing.py Normal file
@@ -0,0 +1,292 @@
"""Cross-signing manager for Matrix bot accounts.

Handles bootstrapping, verification, and recovery of cross-signing keys.
Reusable by any bot (matrix-ai-agent, claude-matrix-bridge, etc.).
"""
import base64
import json
import logging
import os
import subprocess

import canonicaljson
import httpx
import olm.pk

logger = logging.getLogger(__name__)


class CrossSigningManager:
    """Manages cross-signing keys for a Matrix bot device."""

    def __init__(self, homeserver: str, store_path: str, bot_password: str, vault_key: str = ""):
        """
        Args:
            vault_key: Vault key for seed backup/recovery (e.g. "matrix.ai.cross_signing_seeds").
                If empty, vault backup is disabled.
        """
        self.homeserver = homeserver
        self.store_path = store_path
        self.bot_password = bot_password
        self.vault_key = vault_key
        self._xsign_file = os.path.join(store_path, "cross_signing_keys.json")

    async def ensure_cross_signing(self, user_id: str, device_id: str, access_token: str) -> bool:
        """Ensure device is cross-signed. Returns True if already signed or newly signed."""
        headers = {"Authorization": f"Bearer {access_token}"}

        # Check if device already has a cross-signing signature
        if await self._is_device_cross_signed(user_id, device_id, headers):
            logger.info("Device %s already cross-signed, skipping bootstrap", device_id)
            return True

        # Load existing seeds or generate new ones
        master_seed, ss_seed, us_seed = self._load_or_generate_seeds()

        master_key = olm.pk.PkSigning(master_seed)
        self_signing_key = olm.pk.PkSigning(ss_seed)
        user_signing_key = olm.pk.PkSigning(us_seed)

        # Upload cross-signing keys and sign device
        success = await self._upload_and_sign(
            user_id, device_id, access_token,
            master_key, self_signing_key, user_signing_key,
        )

        if success:
            self._save_seeds(master_seed, ss_seed, us_seed)
            # Verify the signature was applied
            if await self._is_device_cross_signed(user_id, device_id, headers):
                logger.info("Device %s cross-signed and verified successfully", device_id)
                return True
            else:
                logger.error("Device %s: signature upload succeeded but verification failed — retrying once", device_id)
                # Retry signing the device (keys already uploaded)
                retry = await self._sign_device(user_id, device_id, access_token, self_signing_key)
                if retry and await self._is_device_cross_signed(user_id, device_id, headers):
                    logger.info("Device %s cross-signed on retry", device_id)
                    return True
                logger.error("Device %s cross-signing failed after retry", device_id)

        return False

    async def verify_cross_signing_status(self, user_id: str, device_id: str, access_token: str) -> dict:
        """Check cross-signing status. Returns dict with status details."""
        headers = {"Authorization": f"Bearer {access_token}"}
        is_signed = await self._is_device_cross_signed(user_id, device_id, headers)
        has_seeds = os.path.exists(self._xsign_file)
        return {
            "device_id": device_id,
            "cross_signed": is_signed,
            "seeds_stored": has_seeds,
        }

    def get_public_keys(self) -> dict | None:
        """Return public key fingerprints (safe to log). Returns None if no seeds."""
        if not os.path.exists(self._xsign_file):
            return None
        master_seed, ss_seed, us_seed = self._load_or_generate_seeds()
        return {
            "master": olm.pk.PkSigning(master_seed).public_key,
            "self_signing": olm.pk.PkSigning(ss_seed).public_key,
            "user_signing": olm.pk.PkSigning(us_seed).public_key,
        }

    # -- Private methods --

    async def _is_device_cross_signed(self, user_id: str, device_id: str, headers: dict) -> bool:
        """Query server to check if device has a cross-signing signature."""
        try:
            async with httpx.AsyncClient(timeout=15.0) as hc:
                resp = await hc.post(
                    f"{self.homeserver}/_matrix/client/v3/keys/query",
                    json={"device_keys": {user_id: [device_id]}},
                    headers=headers,
                )
                if resp.status_code != 200:
                    return False
                device_keys = resp.json().get("device_keys", {}).get(user_id, {}).get(device_id, {})
                sigs = device_keys.get("signatures", {}).get(user_id, {})
                device_key_id = f"ed25519:{device_id}"
                return any(k != device_key_id for k in sigs)
        except Exception as e:
            logger.warning("Cross-signing check failed: %s", e)
            return False

    def _load_or_generate_seeds(self) -> tuple[bytes, bytes, bytes]:
        """Load seeds from file, vault backup, or generate new ones."""
        # 1. Try local file first
        if os.path.exists(self._xsign_file):
            with open(self._xsign_file) as f:
                seeds = json.load(f)
            logger.info("Loaded cross-signing seeds from local store")
            return (
                base64.b64decode(seeds["master_seed"]),
                base64.b64decode(seeds["self_signing_seed"]),
                base64.b64decode(seeds["user_signing_seed"]),
            )

        # 2. Try vault recovery
        if self.vault_key:
            recovered = self._vault_recover_seeds()
            if recovered:
                logger.info("Recovered cross-signing seeds from vault")
                return recovered

        # 3. Generate new
        logger.info("Generating new cross-signing keys")
        return os.urandom(32), os.urandom(32), os.urandom(32)

    def _save_seeds(self, master: bytes, ss: bytes, us: bytes) -> None:
        """Persist seeds to local file and vault."""
        if os.path.exists(self._xsign_file):
            # File exists; still try vault backup if not yet stored
            self._vault_backup_seeds(master, ss, us)
            return
        os.makedirs(os.path.dirname(self._xsign_file), exist_ok=True)
        with open(self._xsign_file, "w") as f:
            json.dump({
                "master_seed": base64.b64encode(master).decode(),
                "self_signing_seed": base64.b64encode(ss).decode(),
                "user_signing_seed": base64.b64encode(us).decode(),
            }, f)
        logger.info("Cross-signing seeds saved to %s", self._xsign_file)
        self._vault_backup_seeds(master, ss, us)

    def _vault_backup_seeds(self, master: bytes, ss: bytes, us: bytes) -> None:
        """Backup seeds to vault for disaster recovery."""
        if not self.vault_key:
            return
        payload = json.dumps({
            "master_seed": base64.b64encode(master).decode(),
            "self_signing_seed": base64.b64encode(ss).decode(),
            "user_signing_seed": base64.b64encode(us).decode(),
        })
        try:
            result = subprocess.run(
                ["vault", "set", self.vault_key],
                input=payload, capture_output=True, text=True, timeout=10,
            )
            if result.returncode == 0:
                logger.info("Cross-signing seeds backed up to vault key %s", self.vault_key)
            else:
                logger.warning("Vault backup failed: %s", result.stderr.strip())
        except (FileNotFoundError, subprocess.TimeoutExpired) as e:
            logger.warning("Vault backup skipped: %s", e)

    def _vault_recover_seeds(self) -> tuple[bytes, bytes, bytes] | None:
        """Attempt to recover seeds from vault."""
        if not self.vault_key:
            return None
        try:
            result = subprocess.run(
                ["vault", "get", self.vault_key],
                capture_output=True, text=True, timeout=10,
            )
            if result.returncode != 0 or not result.stdout.strip():
                return None
            seeds = json.loads(result.stdout.strip())
            return (
                base64.b64decode(seeds["master_seed"]),
                base64.b64decode(seeds["self_signing_seed"]),
                base64.b64decode(seeds["user_signing_seed"]),
            )
        except (FileNotFoundError, subprocess.TimeoutExpired, json.JSONDecodeError, KeyError) as e:
            logger.warning("Vault recovery failed: %s", e)
            return None

    async def _upload_and_sign(
        self,
        user_id: str,
        device_id: str,
        access_token: str,
        master_key: olm.pk.PkSigning,
        self_signing_key: olm.pk.PkSigning,
        user_signing_key: olm.pk.PkSigning,
    ) -> bool:
        """Upload cross-signing keys and sign device."""
        headers = {"Authorization": f"Bearer {access_token}"}

        def make_key(usage: str, pubkey: str) -> dict:
            return {
                "user_id": user_id,
                "usage": [usage],
                "keys": {"ed25519:" + pubkey: pubkey},
            }

        master_obj = make_key("master", master_key.public_key)
        ss_obj = make_key("self_signing", self_signing_key.public_key)
        us_obj = make_key("user_signing", user_signing_key.public_key)

        # Sign sub-keys with master key
        ss_canonical = canonicaljson.encode_canonical_json(ss_obj)
        ss_obj["signatures"] = {user_id: {"ed25519:" + master_key.public_key: master_key.sign(ss_canonical)}}

        us_canonical = canonicaljson.encode_canonical_json(us_obj)
        us_obj["signatures"] = {user_id: {"ed25519:" + master_key.public_key: master_key.sign(us_canonical)}}

        try:
            async with httpx.AsyncClient(timeout=15.0) as hc:
                # Upload cross-signing keys with password auth
                resp = await hc.post(
                    f"{self.homeserver}/_matrix/client/v3/keys/device_signing/upload",
                    json={
                        "master_key": master_obj,
                        "self_signing_key": ss_obj,
                        "user_signing_key": us_obj,
                        "auth": {
                            "type": "m.login.password",
                            "identifier": {"type": "m.id.user", "user": user_id},
                            "password": self.bot_password,
                        },
                    },
                    headers=headers,
                )
                if resp.status_code != 200:
                    logger.error("Cross-signing upload failed: %d %s", resp.status_code, resp.text)
                    return False
                logger.info("Cross-signing keys uploaded (master: %s)", master_key.public_key)

            # Sign the device
            return await self._sign_device(user_id, device_id, access_token, self_signing_key)

        except Exception as e:
            logger.error("Cross-signing bootstrap failed: %s", e, exc_info=True)
            return False

    async def _sign_device(
        self, user_id: str, device_id: str, access_token: str, self_signing_key: olm.pk.PkSigning,
    ) -> bool:
        """Sign a device with the self-signing key."""
        headers = {"Authorization": f"Bearer {access_token}"}

        try:
            async with httpx.AsyncClient(timeout=15.0) as hc:
                qresp = await hc.post(
                    f"{self.homeserver}/_matrix/client/v3/keys/query",
                    json={"device_keys": {user_id: [device_id]}},
                    headers=headers,
                )
                device_obj = qresp.json()["device_keys"][user_id][device_id]
                device_obj.pop("signatures", None)
                device_obj.pop("unsigned", None)

                dk_canonical = canonicaljson.encode_canonical_json(device_obj)
                dk_sig = self_signing_key.sign(dk_canonical)
                device_obj["signatures"] = {
                    user_id: {"ed25519:" + self_signing_key.public_key: dk_sig},
                }

                resp = await hc.post(
                    f"{self.homeserver}/_matrix/client/v3/keys/signatures/upload",
                    json={user_id: {device_id: device_obj}},
                    headers=headers,
                )
                if resp.status_code != 200:
                    logger.error("Device signature upload failed: %d %s", resp.status_code, resp.text)
                    return False
                logger.info("Device %s signed with self-signing key", device_id)
                return True
        except Exception as e:
            logger.error("Device signing failed: %s", e, exc_info=True)
            return False
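A hypothetical bootstrap call at bot startup; in the real bot the user_id, device_id and access token would come from the login response:

async def bootstrap(user_id: str, device_id: str, access_token: str):
    manager = CrossSigningManager(
        homeserver="https://matrix.example.org",
        store_path="/data/store",
        bot_password=os.environ["BOT_PASSWORD"],  # assumed env var
        vault_key="matrix.ai.cross_signing_seeds",
    )
    if not await manager.ensure_cross_signing(user_id, device_id, access_token):
        logger.warning("Cross-signing bootstrap failed; public keys: %s",
                       manager.get_public_keys())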
34 device_trust.py Normal file
@@ -0,0 +1,34 @@
"""Device trust policy: only trust cross-signed devices.

Replaces the insecure auto-trust-all pattern with selective verification
based on cross-signing signatures.
"""
import logging

logger = logging.getLogger(__name__)


class CrossSignedOnlyPolicy:
    """Trust only devices that carry a cross-signing signature.

    A device's signatures dict typically contains its own ed25519:DEVICE_ID
    self-signature. A cross-signed device additionally has a signature from
    the user's self-signing key (ed25519:SELF_SIGNING_PUB). This policy
    checks for that extra signature.
    """

    def should_trust(self, user_id: str, device) -> bool:
        """Return True if device has a cross-signing signature beyond its own."""
        sigs = getattr(device, "signatures", None)
        if not sigs:
            return False

        user_sigs = sigs.get(user_id, {})
        device_self_key = f"ed25519:{device.device_id}"

        # Trust if any signature key is NOT the device's own key
        for key_id in user_sigs:
            if key_id != device_self_key:
                return True

        return False
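A sketch of how the policy could be wired into the bot's verification pass, assuming matrix-nio style device objects that expose .device_id and .signatures:

policy = CrossSignedOnlyPolicy()

def verify_room_devices(client, room_users: list[str]):
    for user_id in room_users:
        for device in client.device_store.active_user_devices(user_id):
            if policy.should_trust(user_id, device):
                client.verify_device(device)
            else:
                logger.info("Leaving %s/%s unverified", user_id, device.device_id)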
@@ -19,11 +19,25 @@ services:
      - LITELLM_BASE_URL
      - LITELLM_API_KEY
      - DEFAULT_MODEL
      - WILDFILES_BASE_URL
      - WILDFILES_ORG
      - BASE_MODEL=${BASE_MODEL:-claude-haiku}
      - ESCALATION_MODEL=${ESCALATION_MODEL:-claude-sonnet}
      - MEMORY_SERVICE_URL=http://memory-service:8090
      - MEMORY_SERVICE_TOKEN
      - PORTAL_URL
      - BOT_API_KEY
      - SKYVERN_BASE_URL=http://skyvern:8000
      - SKYVERN_API_KEY
    ports:
      - "9100:9100"
    volumes:
      - bot-data:/data
      # Mount source files so git pull + restart works without rebuild
      - ./bot.py:/app/bot.py:ro
      - ./voice.py:/app/voice.py:ro
      - ./agent.py:/app/agent.py:ro
      - ./e2ee_patch.py:/app/e2ee_patch.py:ro
      - ./cross_signing.py:/app/cross_signing.py:ro
      - ./device_trust.py:/app/device_trust.py:ro
    depends_on:
      memory-service:
        condition: service_healthy
@@ -37,6 +51,13 @@ services:
      POSTGRES_DB: memories
    volumes:
      - memory-pgdata:/var/lib/postgresql/data
      - ./memory-db-ssl/server.crt:/var/lib/postgresql/server.crt:ro
      - ./memory-db-ssl/server.key:/var/lib/postgresql/server.key:ro
    command: >
      postgres
      -c ssl=on
      -c ssl_cert_file=/var/lib/postgresql/server.crt
      -c ssl_key_file=/var/lib/postgresql/server.key
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U memory -d memories"]
      interval: 5s
@@ -47,10 +68,13 @@ services:
    build: ./memory-service
    restart: unless-stopped
    environment:
      DATABASE_URL: postgresql://memory:${MEMORY_DB_PASSWORD:-memory}@memory-db:5432/memories
      DATABASE_URL: postgresql://memory_app:${MEMORY_APP_PASSWORD}@memory-db:5432/memories?sslmode=require
      MEMORY_ENCRYPTION_KEY: ${MEMORY_ENCRYPTION_KEY}
      MEMORY_DB_OWNER_PASSWORD: ${MEMORY_DB_PASSWORD:-memory}
      LITELLM_BASE_URL: ${LITELLM_BASE_URL}
      LITELLM_API_KEY: ${LITELLM_MASTER_KEY}
      EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
      MEMORY_SERVICE_TOKEN: ${MEMORY_SERVICE_TOKEN:-}
    depends_on:
      memory-db:
        condition: service_healthy
@@ -60,6 +84,54 @@ services:
      timeout: 5s
      retries: 3

  skyvern:
    image: public.ecr.aws/skyvern/skyvern:latest
    restart: unless-stopped
    environment:
      DATABASE_STRING: postgresql+psycopg://skyvern:${SKYVERN_DB_PASSWORD:-skyvern}@skyvern-db:5432/skyvern
      ENABLE_OPENAI_COMPATIBLE: "true"
      OPENAI_COMPATIBLE_API_KEY: ${LITELLM_API_KEY}
      OPENAI_COMPATIBLE_API_BASE: ${LITELLM_BASE_URL}
      OPENAI_COMPATIBLE_MODEL_NAME: gpt-4o
      OPENAI_COMPATIBLE_SUPPORTS_VISION: "true"
      LLM_KEY: OPENAI_COMPATIBLE
      SECONDARY_LLM_KEY: OPENAI_COMPATIBLE
      BROWSER_TYPE: chromium-headful
      ENABLE_CODE_BLOCK: "true"
      ENV: local
      PORT: "8000"
      ALLOWED_ORIGINS: '["http://localhost:8000"]'
    volumes:
      - skyvern-artifacts:/data/artifacts
      - skyvern-videos:/data/videos
    depends_on:
      skyvern-db:
        condition: service_healthy
    healthcheck:
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/api/v1/heartbeat')"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 60s

  skyvern-db:
    image: postgres:14-alpine
    restart: unless-stopped
    environment:
      POSTGRES_USER: skyvern
      POSTGRES_PASSWORD: ${SKYVERN_DB_PASSWORD:-skyvern}
      POSTGRES_DB: skyvern
    volumes:
      - skyvern-pgdata:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U skyvern -d skyvern"]
      interval: 5s
      timeout: 3s
      retries: 5

volumes:
  bot-data:
  memory-pgdata:
  skyvern-pgdata:
  skyvern-artifacts:
  skyvern-videos:
37 hkdf_fix.py Normal file
@@ -0,0 +1,37 @@
"""Patch Rust HKDF key derivation to output 16 bytes (AES-128).

Element Call JS SDK derives 128-bit AES-GCM keys via:
    deriveKey({name:"AES-GCM", length:128}, ...)

The C++ FrameCryptor may allocate a larger derived_key buffer (32+ bytes).
This patch ensures only 16 bytes are filled, matching the JS output.
"""
import sys

path = sys.argv[1]
with open(path) as f:
    content = f.read()

old = 'hkdf.expand(&[0u8; 128], derived_key).is_ok()'
new = """{
    // MAT-258: Derive 16 bytes (AES-128) matching Element Call JS SDK
    let mut buf = [0u8; 16];
    let ok = hkdf.expand(&[0u8; 128], &mut buf).is_ok();
    if ok {
        // Fill first 16 bytes of derived_key, zero-pad rest
        let len = derived_key.len().min(16);
        derived_key[..len].copy_from_slice(&buf[..len]);
        for b in derived_key[len..].iter_mut() { *b = 0; }
    }
    ok
}"""

if old not in content:
    print(f"WARNING: Could not find HKDF expand line in {path}")
    print("File may already be patched or format changed")
    sys.exit(0)

content = content.replace(old, new)
with open(path, 'w') as f:
    f.write(content)
print(f"Patched HKDF output to 16 bytes in {path}")
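The script takes the target Rust source as its only argument (for example: python hkdf_fix.py path/to/frame_cryptor.rs, path hypothetical) and exits 0 whether or not the pattern is found, so rerunning it in a container build stays idempotent.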
@@ -2,5 +2,5 @@ FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY main.py .
COPY main.py migrate_encrypt.py ./
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8090"]
@@ -1,24 +1,103 @@
import os
import logging
import secrets
import time
import hashlib
import base64

import asyncpg
import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from cryptography.fernet import Fernet
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, field_validator

logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)

DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
OWNER_DSN = os.environ.get("OWNER_DATABASE_URL", "postgresql://memory:{pw}@memory-db:5432/memories".format(
    pw=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
))
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")

app = FastAPI(title="Memory Service")
pool: asyncpg.Pool | None = None
owner_pool: asyncpg.Pool | None = None
_pool_healthy = True


async def verify_token(authorization: str | None = Header(None)):
    """Bearer token auth — skipped if MEMORY_SERVICE_TOKEN not configured (dev mode)."""
    if not MEMORY_SERVICE_TOKEN:
        return
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(401, "Missing or invalid Authorization header")
    if not secrets.compare_digest(authorization[7:], MEMORY_SERVICE_TOKEN):
        raise HTTPException(403, "Invalid token")


def _derive_user_key(user_id: str) -> bytes:
    """Derive a per-user Fernet key from master key + user_id via HMAC-SHA256."""
    if not ENCRYPTION_KEY:
        raise RuntimeError("MEMORY_ENCRYPTION_KEY not set")
    derived = hashlib.pbkdf2_hmac(
        "sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1
    )
    return base64.urlsafe_b64encode(derived)


def _encrypt(text: str, user_id: str) -> str:
    """Encrypt text with per-user Fernet key. Returns base64 ciphertext."""
    if not ENCRYPTION_KEY:
        return text
    f = Fernet(_derive_user_key(user_id))
    return f.encrypt(text.encode()).decode()


def _decrypt(ciphertext: str, user_id: str) -> str:
    """Decrypt ciphertext with per-user Fernet key."""
    if not ENCRYPTION_KEY:
        return ciphertext
    try:
        f = Fernet(_derive_user_key(user_id))
        return f.decrypt(ciphertext.encode()).decode()
    except Exception:
        # Plaintext fallback for not-yet-migrated rows
        return ciphertext
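Since the derivation above is deterministic (PBKDF2 with a single iteration is effectively one HMAC-SHA256 pass over master key + user_id), any worker holding the same master key can read ciphertext written by another. A quick roundtrip sketch, assuming MEMORY_ENCRYPTION_KEY is set and using made-up user IDs:

token = _encrypt("prefers metric units", "@alice:example.org")
assert _decrypt(token, "@alice:example.org") == "prefers metric units"
# A different user_id derives a different Fernet key; the InvalidToken
# is caught and the ciphertext comes back unchanged:
assert _decrypt(token, "@bob:example.org") == token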
class ScheduleRequest(BaseModel):
    user_id: str
    room_id: str
    message_text: str
    scheduled_at: float  # Unix timestamp
    repeat_pattern: str = "once"  # once | daily | weekly | weekdays | monthly

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()

    @field_validator('repeat_pattern')
    @classmethod
    def valid_pattern(cls, v):
        allowed = {"once", "daily", "weekly", "weekdays", "monthly"}
        if v not in allowed:
            raise ValueError(f"repeat_pattern must be one of {allowed}")
        return v


class ScheduleCancelRequest(BaseModel):
    id: int
    user_id: str


class StoreRequest(BaseModel):
@@ -26,12 +105,60 @@ class StoreRequest(BaseModel):
    fact: str
    source_room: str = ""

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()


class QueryRequest(BaseModel):
    user_id: str
    query: str
    top_k: int = 10

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()


class ChunkStoreRequest(BaseModel):
    user_id: str
    room_id: str
    chunk_text: str
    summary: str
    source_event_id: str = ""
    original_ts: float = 0.0

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()


class ChunkQueryRequest(BaseModel):
    user_id: str  # REQUIRED — no default
    room_id: str = ""
    query: str
    top_k: int = 5

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()


class ChunkBulkStoreRequest(BaseModel):
    chunks: list[ChunkStoreRequest]


async def _embed(text: str) -> list[float]:
    """Get embedding vector from LiteLLM /embeddings endpoint."""
@@ -45,11 +172,58 @@ async def _embed(text: str) -> list[float]:
    return resp.json()["data"][0]["embedding"]


async def _embed_batch(texts: list[str]) -> list[list[float]]:
    """Get embedding vectors for a batch of texts."""
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings",
            json={"model": EMBED_MODEL, "input": texts},
            headers={"Authorization": f"Bearer {LITELLM_KEY}"},
        )
        resp.raise_for_status()
        data = resp.json()["data"]
        return [item["embedding"] for item in sorted(data, key=lambda x: x["index"])]
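A hypothetical client call against the service; the route and payload shape follow the handlers further down, and the token must match the server's MEMORY_SERVICE_TOKEN:

import asyncio
import httpx

async def remember(fact: str) -> dict:
    async with httpx.AsyncClient(timeout=10.0) as client:
        resp = await client.post(
            "http://memory-service:8090/memories/store",
            headers={"Authorization": "Bearer dev-token"},
            json={"user_id": "@alice:example.org", "fact": fact},
        )
        resp.raise_for_status()
        # {"stored": True} or {"stored": False, "reason": "duplicate"}
        return resp.json()

asyncio.run(remember("prefers metric units"))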
async def _set_rls_user(conn, user_id: str):
    """Set the RLS session variable for the current connection."""
    await conn.execute("SELECT set_config('app.current_user_id', $1, false)", user_id)


async def _ensure_pool():
    """Recreate the connection pool if it was lost."""
    global pool, owner_pool, _pool_healthy
    if pool and _pool_healthy:
        return
    logger.warning("Reconnecting asyncpg pools (healthy=%s, pool=%s)", _pool_healthy, pool is not None)
    try:
        if pool:
            try:
                await pool.close()
            except Exception:
                pass
        if owner_pool:
            try:
                await owner_pool.close()
            except Exception:
                pass
        pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
        owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
        _pool_healthy = True
        logger.info("asyncpg pools reconnected successfully")
    except Exception:
        _pool_healthy = False
        logger.exception("Failed to reconnect asyncpg pools")
        raise
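Note that _set_rls_user only sets a session GUC; it enforces nothing unless matching row-level-security policies exist on the tables. The policy DDL is not part of this diff, but a sketch of the assumed shape that makes the set_config call effective:

RLS_DDL = """
ALTER TABLE memories ENABLE ROW LEVEL SECURITY;
CREATE POLICY memories_per_user ON memories
    USING (user_id = current_setting('app.current_user_id', true));
"""

async def apply_rls(owner_conn):
    # Run once with the owner connection; the restricted app role
    # then only ever sees its own user's rows.
    await owner_conn.execute(RLS_DDL)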
async def _init_db():
|
||||
"""Create pgvector extension and memories table if not exists."""
|
||||
global pool
|
||||
pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
|
||||
async with pool.acquire() as conn:
|
||||
global pool, _pool_healthy
|
||||
# Use owner connection for DDL (CREATE TABLE/INDEX), then create restricted pool
|
||||
owner_conn = await asyncpg.connect(OWNER_DSN)
|
||||
conn = owner_conn
|
||||
try:
|
||||
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
await conn.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
@@ -67,9 +241,60 @@ async def _init_db():
|
||||
await conn.execute(f"""
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_embedding
|
||||
ON memories USING ivfflat (embedding vector_cosine_ops)
|
||||
WITH (lists = 10)
|
||||
WITH (lists = 100)
|
||||
""")
|
||||
logger.info("Database initialized (dims=%d)", EMBED_DIMS)
|
||||
# Conversation chunks table for RAG over chat history
|
||||
await conn.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS conversation_chunks (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
chunk_text TEXT NOT NULL,
|
||||
summary TEXT NOT NULL,
|
||||
source_event_id TEXT DEFAULT '',
|
||||
original_ts DOUBLE PRECISION NOT NULL,
|
||||
embedding vector({EMBED_DIMS}) NOT NULL
|
||||
)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_user_id ON conversation_chunks (user_id)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
|
||||
""")
|
||||
# Scheduled messages table for reminders
|
||||
await conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS scheduled_messages (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
message_text TEXT NOT NULL,
|
||||
scheduled_at DOUBLE PRECISION NOT NULL,
|
||||
created_at DOUBLE PRECISION NOT NULL,
|
||||
status TEXT DEFAULT 'pending',
|
||||
repeat_pattern TEXT DEFAULT 'once',
|
||||
repeat_interval_seconds INTEGER DEFAULT 0,
|
||||
last_sent_at DOUBLE PRECISION DEFAULT 0
|
||||
)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_scheduled_user_id ON scheduled_messages (user_id)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_scheduled_status ON scheduled_messages (status, scheduled_at)
|
||||
""")
|
||||
finally:
|
||||
await owner_conn.close()
|
||||
# Create restricted pool for all request handlers (RLS applies)
|
||||
pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
|
||||
# Owner pool for admin queries (bypasses RLS) — 1 connection only
|
||||
global owner_pool
|
||||
owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
|
||||
_pool_healthy = True
|
||||
logger.info("Database initialized (dims=%d, encryption=%s)", EMBED_DIMS, "ON" if ENCRYPTION_KEY else "OFF")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@@ -81,27 +306,45 @@ async def startup():
|
||||
async def shutdown():
|
||||
if pool:
|
||||
await pool.close()
|
||||
if owner_pool:
|
||||
await owner_pool.close()
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
if pool:
|
||||
async with pool.acquire() as conn:
|
||||
count = await conn.fetchval("SELECT count(*) FROM memories")
|
||||
return {"status": "ok", "total_memories": count}
|
||||
return {"status": "no_db"}
|
||||
global _pool_healthy
|
||||
try:
|
||||
await _ensure_pool()
|
||||
async with owner_pool.acquire() as conn:
|
||||
mem_count = await conn.fetchval("SELECT count(*) FROM memories")
|
||||
chunk_count = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
|
||||
sched_count = await conn.fetchval("SELECT count(*) FROM scheduled_messages WHERE status = 'pending'")
|
||||
return {
|
||||
"status": "ok",
|
||||
"total_memories": mem_count,
|
||||
"total_chunks": chunk_count,
|
||||
"pending_reminders": sched_count,
|
||||
"encryption": "on" if ENCRYPTION_KEY else "off",
|
||||
}
|
||||
except Exception as e:
|
||||
_pool_healthy = False
|
||||
logger.error("Health check failed: %s", e)
|
||||
return {"status": "unhealthy", "error": str(e)}
|
||||
|
||||
|
||||
@app.post("/memories/store")
|
||||
async def store_memory(req: StoreRequest):
|
||||
"""Embed fact, deduplicate by cosine similarity, insert."""
|
||||
async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
|
||||
"""Embed fact, deduplicate by cosine similarity, insert encrypted."""
|
||||
if not req.fact.strip():
|
||||
raise HTTPException(400, "Empty fact")
|
||||
|
||||
embedding = await _embed(req.fact)
|
||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, req.user_id)
|
||||
|
||||
# Check for duplicates (cosine similarity > threshold)
|
||||
dup = await conn.fetchval(
|
||||
"""
|
||||
@@ -116,24 +359,28 @@ async def store_memory(req: StoreRequest):
|
||||
logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, dup)
|
||||
return {"stored": False, "reason": "duplicate"}
|
||||
|
||||
encrypted_fact = _encrypt(req.fact.strip(), req.user_id)
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
|
||||
VALUES ($1, $2, $3, $4, $5::vector)
|
||||
""",
|
||||
req.user_id, req.fact.strip(), req.source_room, time.time(), vec_literal,
|
||||
req.user_id, encrypted_fact, req.source_room, time.time(), vec_literal,
|
||||
)
|
||||
logger.info("Stored memory for %s: %s", req.user_id, req.fact[:60])
|
||||
return {"stored": True}
|
||||
|
||||
|
||||
@app.post("/memories/query")
|
||||
async def query_memories(req: QueryRequest):
|
||||
async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
|
||||
"""Embed query, return top-K similar facts for user."""
|
||||
embedding = await _embed(req.query)
|
||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, req.user_id)
|
||||
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT fact, source_room, created_at,
|
||||
@@ -148,7 +395,7 @@ async def query_memories(req: QueryRequest):
|
||||
|
||||
results = [
|
||||
{
|
||||
"fact": r["fact"],
|
||||
"fact": _decrypt(r["fact"], req.user_id),
|
||||
"source_room": r["source_room"],
|
||||
"created_at": r["created_at"],
|
||||
"similarity": float(r["similarity"]),
|
||||
@@ -159,9 +406,11 @@ async def query_memories(req: QueryRequest):
|
||||
|
||||
|
||||
@app.delete("/memories/{user_id}")
|
||||
async def delete_user_memories(user_id: str):
|
||||
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||
"""GDPR delete — remove all memories for a user."""
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
result = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
|
||||
count = int(result.split()[-1])
|
||||
logger.info("Deleted %d memories for %s", count, user_id)
|
||||
@@ -169,9 +418,11 @@ async def delete_user_memories(user_id: str):
|
||||
|
||||
|
||||
@app.get("/memories/{user_id}")
|
||||
async def list_user_memories(user_id: str):
|
||||
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||
"""List all memories for a user (for UI/debug)."""
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT fact, source_room, created_at
|
||||
@@ -185,7 +436,330 @@ async def list_user_memories(user_id: str):
|
||||
"user_id": user_id,
|
||||
"count": len(rows),
|
||||
"memories": [
|
||||
{"fact": r["fact"], "source_room": r["source_room"], "created_at": r["created_at"]}
|
||||
{"fact": _decrypt(r["fact"], user_id), "source_room": r["source_room"], "created_at": r["created_at"]}
|
||||
for r in rows
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# --- Conversation Chunks ---


@app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
    """Store a conversation chunk with its summary embedding, encrypted."""
    if not req.summary.strip():
        raise HTTPException(400, "Empty summary")

    embedding = await _embed(req.summary)
    vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
    ts = req.original_ts or time.time()

    encrypted_text = _encrypt(req.chunk_text, req.user_id)
    encrypted_summary = _encrypt(req.summary, req.user_id)

    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        await conn.execute(
            """
            INSERT INTO conversation_chunks
                (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
            """,
            req.user_id, req.room_id, encrypted_text, encrypted_summary,
            req.source_event_id, ts, vec_literal,
        )
    logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60])
    return {"stored": True}
@app.post("/chunks/query")
|
||||
async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
|
||||
"""Semantic search over conversation chunks. Filter by user_id and/or room_id."""
|
||||
embedding = await _embed(req.query)
|
||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
|
||||
# Build WHERE clause — user_id is always required
|
||||
conditions = [f"user_id = $2"]
|
||||
params: list = [vec_literal, req.user_id]
|
||||
idx = 3
|
||||
|
||||
if req.room_id:
|
||||
conditions.append(f"room_id = ${idx}")
|
||||
params.append(req.room_id)
|
||||
idx += 1
|
||||
|
||||
where = f"WHERE {' AND '.join(conditions)}"
|
||||
params.append(req.top_k)
|
||||
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, req.user_id)
|
||||
|
||||
rows = await conn.fetch(
|
||||
f"""
|
||||
SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
|
||||
1 - (embedding <=> $1::vector) AS similarity
|
||||
FROM conversation_chunks
|
||||
{where}
|
||||
ORDER BY embedding <=> $1::vector
|
||||
LIMIT ${idx}
|
||||
""",
|
||||
*params,
|
||||
)
|
||||
|
||||
results = [
|
||||
{
|
||||
"chunk_text": _decrypt(r["chunk_text"], r["user_id"]),
|
||||
"summary": _decrypt(r["summary"], r["user_id"]),
|
||||
"room_id": r["room_id"],
|
||||
"user_id": r["user_id"],
|
||||
"original_ts": r["original_ts"],
|
||||
"source_event_id": r["source_event_id"],
|
||||
"similarity": float(r["similarity"]),
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
return {"results": results}
|
||||
|
||||
|
||||
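Worth pausing on the parameter numbering: $1 is always the query vector, $2 the user_id, and the final placeholder index doubles as the LIMIT slot, which is why top_k is appended last. From the caller's side the route is plain JSON over HTTP. A minimal client sketch, assuming the service listens on http://localhost:8900 and that verify_token checks a bearer token from MEMORY_API_TOKEN (port, env var, and auth scheme are all assumptions, not in this diff):

import asyncio
import os

import httpx


async def find_context(question: str) -> None:
    headers = {"Authorization": f"Bearer {os.environ['MEMORY_API_TOKEN']}"}  # assumed scheme
    payload = {
        "query": question,
        "user_id": "@alice:example.org",
        "room_id": None,  # None searches across all of the user's rooms
        "top_k": 5,
    }
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post("http://localhost:8900/chunks/query",
                                 json=payload, headers=headers)
        resp.raise_for_status()
        for hit in resp.json()["results"]:
            print(f"{hit['similarity']:.2f}  {hit['summary']}")


asyncio.run(find_context("what did we decide about the deploy window?"))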
@app.post("/chunks/bulk-store")
|
||||
async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
|
||||
"""Batch store conversation chunks. Embeds summaries in batches of 20."""
|
||||
if not req.chunks:
|
||||
return {"stored": 0}
|
||||
|
||||
stored = 0
|
||||
batch_size = 20
|
||||
|
||||
for i in range(0, len(req.chunks), batch_size):
|
||||
batch = req.chunks[i:i + batch_size]
|
||||
summaries = [c.summary.strip() for c in batch]
|
||||
|
||||
try:
|
||||
embeddings = await _embed_batch(summaries)
|
||||
except Exception:
|
||||
logger.error("Batch embed failed for chunks %d-%d", i, i + len(batch), exc_info=True)
|
||||
continue
|
||||
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
for chunk, embedding in zip(batch, embeddings):
|
||||
await _set_rls_user(conn, chunk.user_id)
|
||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
ts = chunk.original_ts or time.time()
|
||||
encrypted_text = _encrypt(chunk.chunk_text, chunk.user_id)
|
||||
encrypted_summary = _encrypt(chunk.summary, chunk.user_id)
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO conversation_chunks
|
||||
(user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
|
||||
""",
|
||||
chunk.user_id, chunk.room_id, encrypted_text, encrypted_summary,
|
||||
chunk.source_event_id, ts, vec_literal,
|
||||
)
|
||||
stored += 1
|
||||
|
||||
logger.info("Bulk stored %d chunks", stored)
|
||||
return {"stored": stored}
|
||||
|
||||
|
||||
@app.get("/chunks/{user_id}/count")
|
||||
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
|
||||
"""Count conversation chunks for a user."""
|
||||
await _ensure_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await _set_rls_user(conn, user_id)
|
||||
count = await conn.fetchval(
|
||||
"SELECT count(*) FROM conversation_chunks WHERE user_id = $1", user_id,
|
||||
)
|
||||
return {"user_id": user_id, "count": count}
|
||||
|
||||
|
||||
# --- Scheduled Messages ---

import calendar
import datetime


def _compute_repeat_interval(pattern: str) -> int:
    """Compute repeat_interval_seconds from pattern name."""
    return {
        "once": 0,
        "daily": 86400,
        "weekly": 604800,
        "weekdays": 86400,  # special handling in mark-sent
        "monthly": 0,       # special handling in mark-sent
    }.get(pattern, 0)


def _next_scheduled_at(current_ts: float, pattern: str) -> float:
    """Compute the next scheduled_at timestamp for recurring patterns."""
    dt = datetime.datetime.fromtimestamp(current_ts, tz=datetime.timezone.utc)

    if pattern == "daily":
        return current_ts + 86400.0
    elif pattern == "weekly":
        return current_ts + 604800.0
    elif pattern == "weekdays":
        next_dt = dt + datetime.timedelta(days=1)
        while next_dt.weekday() >= 5:  # Skip Sat(5), Sun(6)
            next_dt += datetime.timedelta(days=1)
        return next_dt.timestamp()
    elif pattern == "monthly":
        month = dt.month + 1
        year = dt.year + (month - 1) // 12
        month = (month - 1) % 12 + 1
        day = min(dt.day, calendar.monthrange(year, month)[1])
        return dt.replace(year=year, month=month, day=day).timestamp()
    return current_ts
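The monthly branch clamps day-of-month, so a reminder created on the 31st still fires in shorter months. A quick sanity check of that clamping against the helper above (dates are hypothetical):

import datetime

jan31 = datetime.datetime(2025, 1, 31, 9, 0, tzinfo=datetime.timezone.utc).timestamp()
nxt = _next_scheduled_at(jan31, "monthly")
print(datetime.datetime.fromtimestamp(nxt, tz=datetime.timezone.utc))
# -> 2025-02-28 09:00:00+00:00 (day clamped from 31 to 28)

Note the clamp is sticky: once Jan 31 becomes Feb 28, later months are computed from the 28th.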
MAX_REMINDERS_PER_USER = 50


@app.post("/scheduled/create")
async def create_scheduled(req: ScheduleRequest, _: None = Depends(verify_token)):
    """Create a new scheduled message/reminder."""
    now = time.time()
    if req.scheduled_at <= now:
        raise HTTPException(400, "scheduled_at must be in the future")

    # Check max reminders per user
    await _ensure_pool()
    async with owner_pool.acquire() as conn:
        count = await conn.fetchval(
            "SELECT count(*) FROM scheduled_messages WHERE user_id = $1 AND status = 'pending'",
            req.user_id,
        )
        if count >= MAX_REMINDERS_PER_USER:
            raise HTTPException(400, f"Maximum {MAX_REMINDERS_PER_USER} active reminders per user")

        msg_text = req.message_text[:2000]  # Truncate long messages
        interval = _compute_repeat_interval(req.repeat_pattern)

        row_id = await conn.fetchval(
            """
            INSERT INTO scheduled_messages
                (user_id, room_id, message_text, scheduled_at, created_at, status, repeat_pattern, repeat_interval_seconds)
            VALUES ($1, $2, $3, $4, $5, 'pending', $6, $7)
            RETURNING id
            """,
            req.user_id, req.room_id, msg_text, req.scheduled_at, now,
            req.repeat_pattern, interval,
        )
    logger.info("Created reminder #%d for %s at %.0f (%s)", row_id, req.user_id, req.scheduled_at, req.repeat_pattern)
    return {"id": row_id, "created": True}
@app.get("/scheduled/{user_id}")
|
||||
async def list_scheduled(user_id: str, _: None = Depends(verify_token)):
|
||||
"""List all pending/active reminders for a user."""
|
||||
await _ensure_pool()
|
||||
async with owner_pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, message_text, scheduled_at, repeat_pattern, status
|
||||
FROM scheduled_messages
|
||||
WHERE user_id = $1 AND status = 'pending'
|
||||
ORDER BY scheduled_at
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"reminders": [
|
||||
{
|
||||
"id": r["id"],
|
||||
"message_text": r["message_text"],
|
||||
"scheduled_at": r["scheduled_at"],
|
||||
"repeat_pattern": r["repeat_pattern"],
|
||||
"status": r["status"],
|
||||
}
|
||||
for r in rows
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@app.delete("/scheduled/{user_id}/{reminder_id}")
|
||||
async def cancel_scheduled(user_id: str, reminder_id: int, _: None = Depends(verify_token)):
|
||||
"""Cancel a reminder. Only the owner can cancel."""
|
||||
await _ensure_pool()
|
||||
async with owner_pool.acquire() as conn:
|
||||
result = await conn.execute(
|
||||
"""
|
||||
UPDATE scheduled_messages SET status = 'cancelled'
|
||||
WHERE id = $1 AND user_id = $2 AND status = 'pending'
|
||||
""",
|
||||
reminder_id, user_id,
|
||||
)
|
||||
count = int(result.split()[-1])
|
||||
if count == 0:
|
||||
raise HTTPException(404, "Reminder not found or already cancelled")
|
||||
logger.info("Cancelled reminder #%d for %s", reminder_id, user_id)
|
||||
return {"cancelled": True, "id": reminder_id}
|
||||
|
||||
|
||||
@app.post("/scheduled/due")
|
||||
async def get_due_messages(_: None = Depends(verify_token)):
|
||||
"""Return all messages that are due (scheduled_at <= now, status = pending)."""
|
||||
now = time.time()
|
||||
await _ensure_pool()
|
||||
async with owner_pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, user_id, room_id, message_text, scheduled_at, repeat_pattern
|
||||
FROM scheduled_messages
|
||||
WHERE scheduled_at <= $1 AND status = 'pending'
|
||||
ORDER BY scheduled_at
|
||||
LIMIT 100
|
||||
""",
|
||||
now,
|
||||
)
|
||||
return {
|
||||
"due": [
|
||||
{
|
||||
"id": r["id"],
|
||||
"user_id": r["user_id"],
|
||||
"room_id": r["room_id"],
|
||||
"message_text": r["message_text"],
|
||||
"scheduled_at": r["scheduled_at"],
|
||||
"repeat_pattern": r["repeat_pattern"],
|
||||
}
|
||||
for r in rows
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@app.post("/scheduled/{reminder_id}/mark-sent")
|
||||
async def mark_sent(reminder_id: int, _: None = Depends(verify_token)):
|
||||
"""Mark a reminder as sent. For recurring, compute next scheduled_at."""
|
||||
now = time.time()
|
||||
await _ensure_pool()
|
||||
async with owner_pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT repeat_pattern, scheduled_at FROM scheduled_messages WHERE id = $1",
|
||||
reminder_id,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(404, "Reminder not found")
|
||||
|
||||
if row["repeat_pattern"] == "once":
|
||||
await conn.execute(
|
||||
"UPDATE scheduled_messages SET status = 'sent', last_sent_at = $1 WHERE id = $2",
|
||||
now, reminder_id,
|
||||
)
|
||||
else:
|
||||
next_at = _next_scheduled_at(row["scheduled_at"], row["repeat_pattern"])
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE scheduled_messages
|
||||
SET scheduled_at = $1, last_sent_at = $2
|
||||
WHERE id = $3
|
||||
""",
|
||||
next_at, now, reminder_id,
|
||||
)
|
||||
logger.info("Marked reminder #%d as sent (pattern=%s)", reminder_id, row["repeat_pattern"])
|
||||
return {"marked": True}
|
||||
|
||||
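Together, POST /scheduled/due and mark-sent give at-least-once delivery: the sender acknowledges only after posting to the room, so a crash between the two re-delivers rather than drops. A minimal sketch of the poller the bot side would run, assuming the service URL and a MEMORY_API_TOKEN env var (both assumptions, not part of this diff):

import asyncio
import os

import httpx

MEMORY_URL = os.environ.get("MEMORY_URL", "http://memory-service:8000")      # assumed
HEADERS = {"Authorization": f"Bearer {os.environ['MEMORY_API_TOKEN']}"}      # assumed scheme


async def reminder_loop(send_text, poll_every: int = 30):
    """Poll for due reminders, deliver them, then acknowledge via mark-sent."""
    async with httpx.AsyncClient(timeout=15.0) as client:
        while True:
            resp = await client.post(f"{MEMORY_URL}/scheduled/due", headers=HEADERS)
            resp.raise_for_status()
            for msg in resp.json()["due"]:
                await send_text(msg["room_id"], msg["message_text"])
                # Ack only after delivery so a crash re-delivers rather than drops
                await client.post(f"{MEMORY_URL}/scheduled/{msg['id']}/mark-sent", headers=HEADERS)
            await asyncio.sleep(poll_every)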
memory-service/migrate_encrypt.py (new file, 95 lines)
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
"""MAT-107: One-time migration to encrypt existing plaintext memory data.

Run INSIDE the memory-service container after deploying new code:
    docker exec -it matrix-ai-agent-memory-service-1 python migrate_encrypt.py

Connects as owner (memory) to bypass RLS.
"""
import os
import sys
import hashlib
import base64
import asyncio

import asyncpg
from cryptography.fernet import Fernet, InvalidToken

OWNER_DSN = os.environ.get(
    "OWNER_DATABASE_URL",
    "postgresql://memory:{password}@memory-db:5432/memories".format(
        password=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
    ),
)
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")


def _derive_user_key(user_id: str) -> bytes:
    derived = hashlib.pbkdf2_hmac("sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1)
    return base64.urlsafe_b64encode(derived)


def _encrypt(text: str, user_id: str) -> str:
    f = Fernet(_derive_user_key(user_id))
    return f.encrypt(text.encode()).decode()


def _is_encrypted(text: str, user_id: str) -> bool:
    """Check if text is already Fernet-encrypted."""
    try:
        f = Fernet(_derive_user_key(user_id))
        f.decrypt(text.encode())
        return True
    except InvalidToken:
        return False
    except Exception:
        # Malformed/plaintext input can raise other errors; treat as not encrypted
        return False


async def migrate():
    if not ENCRYPTION_KEY:
        print("ERROR: MEMORY_ENCRYPTION_KEY not set")
        sys.exit(1)

    conn = await asyncpg.connect(OWNER_DSN)

    # Migrate memories
    rows = await conn.fetch("SELECT id, user_id, fact FROM memories ORDER BY id")
    print(f"Migrating {len(rows)} memories...")
    encrypted = 0
    skipped = 0
    for row in rows:
        if _is_encrypted(row["fact"], row["user_id"]):
            skipped += 1
            continue
        enc_fact = _encrypt(row["fact"], row["user_id"])
        await conn.execute("UPDATE memories SET fact = $1 WHERE id = $2", enc_fact, row["id"])
        encrypted += 1
        if encrypted % 100 == 0:
            print(f"  memories: {encrypted}/{len(rows)} encrypted")
    print(f"Memories done: {encrypted} encrypted, {skipped} already encrypted")

    # Migrate conversation_chunks
    rows = await conn.fetch("SELECT id, user_id, chunk_text, summary FROM conversation_chunks ORDER BY id")
    print(f"Migrating {len(rows)} chunks...")
    encrypted = 0
    skipped = 0
    for row in rows:
        if _is_encrypted(row["chunk_text"], row["user_id"]):
            skipped += 1
            continue
        enc_text = _encrypt(row["chunk_text"], row["user_id"])
        enc_summary = _encrypt(row["summary"], row["user_id"])
        await conn.execute(
            "UPDATE conversation_chunks SET chunk_text = $1, summary = $2 WHERE id = $3",
            enc_text, enc_summary, row["id"],
        )
        encrypted += 1
        if encrypted % 500 == 0:
            print(f"  chunks: {encrypted}/{len(rows)} encrypted")
    print(f"Chunks done: {encrypted} encrypted, {skipped} already encrypted")

    await conn.close()
    print("Migration complete.")


if __name__ == "__main__":
    asyncio.run(migrate())
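The per-user key is PBKDF2 over MEMORY_ENCRYPTION_KEY with the user_id as salt and a single iteration, so the master key itself must be high-entropy. A round-trip check that the script's derivation matches what the service expects, reusing the helpers above (user IDs are made up):

# Run in the same environment as the helpers above (MEMORY_ENCRYPTION_KEY already set).
from cryptography.fernet import Fernet, InvalidToken

token = _encrypt("favourite colour is teal", "@alice:example.org")
f = Fernet(_derive_user_key("@alice:example.org"))
assert f.decrypt(token.encode()).decode() == "favourite colour is teal"

try:
    Fernet(_derive_user_key("@bob:example.org")).decrypt(token.encode())
except InvalidToken:
    pass  # a different user_id derives a different key, as intended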
@@ -3,3 +3,4 @@ uvicorn>=0.34,<1.0
asyncpg>=0.30,<1.0
pgvector>=0.3,<1.0
httpx>=0.27,<1.0
+cryptography>=44.0,<45.0
migrate_rls.sql (new file, 40 lines)
@@ -0,0 +1,40 @@
-- MAT-107: Row-Level Security for memory tables
-- Run as superuser (memory) which owns the tables

-- Create restricted role for memory-service
DO $$
BEGIN
    IF NOT EXISTS (SELECT FROM pg_roles WHERE rolname = 'memory_app') THEN
        CREATE ROLE memory_app LOGIN PASSWORD 'OhugBZP4g4d7rk3OszOq1Xe3yo7hQwEn';
    END IF;
END
$$;

-- Grant permissions to memory_app
GRANT CONNECT ON DATABASE memories TO memory_app;
GRANT USAGE ON SCHEMA public TO memory_app;
GRANT SELECT, INSERT, DELETE ON memories, conversation_chunks TO memory_app;
GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO memory_app;

-- Ensure tables are owned by memory (superuser) so RLS doesn't apply to owner
ALTER TABLE memories OWNER TO memory;
ALTER TABLE conversation_chunks OWNER TO memory;

-- Enable RLS
ALTER TABLE memories ENABLE ROW LEVEL SECURITY;
ALTER TABLE conversation_chunks ENABLE ROW LEVEL SECURITY;

-- Drop existing policies if re-running
DROP POLICY IF EXISTS user_isolation_memories ON memories;
DROP POLICY IF EXISTS user_isolation_chunks ON conversation_chunks;

-- RLS policies: rows visible only when session var matches user_id
-- current_setting with missing_ok=true returns NULL when the var is unset, so no rows match
CREATE POLICY user_isolation_memories ON memories
    USING (user_id = current_setting('app.current_user_id', true));

CREATE POLICY user_isolation_chunks ON conversation_chunks
    USING (user_id = current_setting('app.current_user_id', true));

-- Verify
SELECT tablename, rowsecurity FROM pg_tables WHERE tablename IN ('memories', 'conversation_chunks');
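Both policies key off the app.current_user_id session variable, which the service's _set_rls_user helper has to set on every checked-out connection; that helper is outside this diff, so the following is a plausible reconstruction, not the actual code:

# Hypothetical reconstruction of the helper used throughout main.py above.
async def _set_rls_user(conn, user_id: str) -> None:
    # is_local=false pins the value for the whole session, so it must be
    # re-set on every pool checkout before touching RLS-protected tables.
    await conn.execute(
        "SELECT set_config('app.current_user_id', $1, false)", user_id
    )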
pipelines/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
"""Pipeline orchestration engine for Matrix bot."""
|
||||
|
||||
from .engine import PipelineEngine
|
||||
from .state import PipelineStateManager
|
||||
|
||||
__all__ = ["PipelineEngine", "PipelineStateManager"]
|
||||
pipelines/approval.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""Approval handling — maps Matrix reactions to pipeline approvals."""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reaction emoji to response mapping
|
||||
APPROVAL_REACTIONS = {
|
||||
"\U0001f44d": "approve", # thumbs up
|
||||
"\U0001f44e": "decline", # thumbs down
|
||||
"\u2705": "approve", # check mark
|
||||
"\u274c": "decline", # cross mark
|
||||
}
|
||||
|
||||
|
||||
def reaction_to_response(reaction_key: str) -> str | None:
|
||||
"""Map a reaction emoji to an approval response."""
|
||||
# Strip variation selectors (U+FE0E, U+FE0F) — Element often appends these
|
||||
cleaned = reaction_key.replace("\ufe0f", "").replace("\ufe0e", "")
|
||||
return APPROVAL_REACTIONS.get(cleaned) or APPROVAL_REACTIONS.get(reaction_key)
|
||||
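Element commonly appends the U+FE0F variation selector to reactions, which is why the lookup strips it before matching. A quick check of the three paths:

from pipelines.approval import reaction_to_response

assert reaction_to_response("\U0001f44d\ufe0f") == "approve"   # selector stripped
assert reaction_to_response("\u274c") == "decline"
assert reaction_to_response("\U0001f389") is None              # unmapped emoji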
pipelines/engine.py (new file, 407 lines)
@@ -0,0 +1,407 @@
"""Pipeline execution engine — runs steps sequentially with output chaining."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
import httpx
|
||||
import sentry_sdk
|
||||
|
||||
from .state import PipelineStateManager
|
||||
from .steps import execute_step
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Transient errors eligible for retry
|
||||
TRANSIENT_EXCEPTIONS = (
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.ReadTimeout,
|
||||
ConnectionError,
|
||||
OSError,
|
||||
)
|
||||
|
||||
# Sentinel for failed step (distinct from None output)
|
||||
_STEP_FAILED = object()
|
||||
|
||||
|
||||
class PipelineEngine:
|
||||
"""Executes pipeline steps sequentially, managing state and output chaining."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state: PipelineStateManager,
|
||||
send_text,
|
||||
matrix_client,
|
||||
llm_client=None,
|
||||
default_model: str = "claude-haiku",
|
||||
escalation_model: str = "claude-sonnet",
|
||||
on_approval_registered=None,
|
||||
):
|
||||
self.state = state
|
||||
self.send_text = send_text
|
||||
self.matrix_client = matrix_client
|
||||
self.llm = llm_client
|
||||
self.default_model = default_model
|
||||
self.escalation_model = escalation_model
|
||||
self.on_approval_registered = on_approval_registered # callback(event_id, execution_id)
|
||||
# Track active approval listeners: execution_id -> asyncio.Future
|
||||
self._approval_futures: dict[str, asyncio.Future] = {}
|
||||
|
||||
def render_template(self, template: str, context: dict) -> str:
|
||||
"""Simple Jinja2-like template rendering: {{ step_name.output }}"""
|
||||
def replacer(match):
|
||||
expr = match.group(1).strip()
|
||||
parts = expr.split(".")
|
||||
try:
|
||||
value = context
|
||||
for part in parts:
|
||||
if isinstance(value, dict):
|
||||
value = value[part]
|
||||
else:
|
||||
return match.group(0)
|
||||
return str(value)
|
||||
except (KeyError, TypeError):
|
||||
return match.group(0)
|
||||
|
||||
return re.sub(r"\{\{\s*(.+?)\s*\}\}", replacer, template)
|
||||
|
||||
def evaluate_condition(self, condition: str, context: dict) -> bool:
|
||||
"""Evaluate a simple condition like {{ step.response == 'approve' }}"""
|
||||
rendered = self.render_template(condition, context)
|
||||
# Strip template markers if still present
|
||||
rendered = rendered.strip().strip("{}").strip()
|
||||
# Simple equality check
|
||||
if "==" in rendered:
|
||||
left, right = rendered.split("==", 1)
|
||||
return left.strip().strip("'\"") == right.strip().strip("'\"")
|
||||
if "!=" in rendered:
|
||||
left, right = rendered.split("!=", 1)
|
||||
return left.strip().strip("'\"") != right.strip().strip("'\"")
|
||||
# Truthy check
|
||||
return bool(rendered) and rendered.lower() not in ("false", "none", "0", "")
|
||||
|
||||
async def run(self, pipeline: dict, trigger_data: dict | None = None) -> None:
|
||||
"""Execute a full pipeline run."""
|
||||
pipeline_id = pipeline["id"]
|
||||
pipeline_name = pipeline["name"]
|
||||
target_room = pipeline["targetRoom"]
|
||||
steps = pipeline.get("steps", [])
|
||||
user_id = pipeline.get("userId", "")
|
||||
max_retries = pipeline.get("maxRetries", 0)
|
||||
|
||||
if not steps:
|
||||
logger.warning("Pipeline %s has no steps", pipeline_name)
|
||||
return
|
||||
|
||||
with sentry_sdk.start_transaction(op="pipeline.execute", name=pipeline_name) as txn:
|
||||
txn.set_tag("pipeline_id", pipeline_id)
|
||||
txn.set_tag("user_id", user_id)
|
||||
txn.set_tag("trigger_type", pipeline.get("triggerType", "unknown"))
|
||||
txn.set_tag("step_count", len(steps))
|
||||
|
||||
# Create execution record
|
||||
execution = await self.state.create_execution(pipeline_id, trigger_data)
|
||||
execution_id = execution["id"]
|
||||
txn.set_tag("execution_id", execution_id)
|
||||
|
||||
# Audit: execution started
|
||||
asyncio.create_task(self.state.log_event(
|
||||
execution_id, "execution_started", message=f"Trigger: {pipeline.get('triggerType', 'manual')}"
|
||||
))
|
||||
|
||||
context: dict[str, dict] = {}
|
||||
if trigger_data:
|
||||
context["trigger"] = trigger_data
|
||||
|
||||
step_results: list[dict] = []
|
||||
|
||||
try:
|
||||
for i, step in enumerate(steps):
|
||||
step_name = step.get("name", f"step_{i}")
|
||||
step_type = step.get("type", "")
|
||||
|
||||
# Resume support: skip already-completed steps
|
||||
if i < len(step_results) and step_results[i].get("status") in ("success", "skipped"):
|
||||
context[step_name] = {
|
||||
"output": step_results[i].get("output", ""),
|
||||
"status": step_results[i]["status"],
|
||||
}
|
||||
continue
|
||||
|
||||
# Evaluate condition
|
||||
condition = step.get("if")
|
||||
if condition and not self.evaluate_condition(condition, context):
|
||||
logger.info("Pipeline %s: skipping step %s (condition not met)", pipeline_name, step_name)
|
||||
result = {"name": step_name, "output": "skipped", "status": "skipped", "timestamp": time.time()}
|
||||
step_results.append(result)
|
||||
context[step_name] = {"output": "skipped", "status": "skipped"}
|
||||
continue
|
||||
|
||||
# Update execution state
|
||||
await self.state.update_execution(
|
||||
execution_id,
|
||||
currentStep=i,
|
||||
stepResults=step_results,
|
||||
state="running",
|
||||
)
|
||||
|
||||
logger.info("Pipeline %s: executing step %s (%s)", pipeline_name, step_name, step_type)
|
||||
|
||||
# Render templates in step config
|
||||
rendered_step = {}
|
||||
for key, value in step.items():
|
||||
if isinstance(value, str):
|
||||
rendered_step[key] = self.render_template(value, context)
|
||||
elif isinstance(value, dict):
|
||||
rendered_step[key] = {
|
||||
k: self.render_template(v, context) if isinstance(v, str) else v
|
||||
for k, v in value.items()
|
||||
}
|
||||
else:
|
||||
rendered_step[key] = value
|
||||
|
||||
# Execute step with retry logic
|
||||
output = await self._execute_step_with_retry(
|
||||
step_type=step_type,
|
||||
step_name=step_name,
|
||||
rendered_step=rendered_step,
|
||||
context=context,
|
||||
target_room=target_room,
|
||||
execution_id=execution_id,
|
||||
timeout_s=step.get("timeout_s", 60),
|
||||
max_retries=max_retries,
|
||||
step_results=step_results,
|
||||
pipeline_name=pipeline_name,
|
||||
)
|
||||
|
||||
if output is _STEP_FAILED:
|
||||
return # step handler already updated state
|
||||
|
||||
result = {
|
||||
"name": step_name,
|
||||
"output": output,
|
||||
"status": "success",
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
step_results.append(result)
|
||||
context[step_name] = {"output": output, "status": "success"}
|
||||
|
||||
# For approval steps, also store the response field
|
||||
if step_type == "approval":
|
||||
context[step_name]["response"] = output
|
||||
|
||||
# Audit: step completed
|
||||
asyncio.create_task(self.state.log_event(
|
||||
execution_id, "step_completed", step_name=step_name, status="success"
|
||||
))
|
||||
|
||||
# All steps completed
|
||||
await self.state.update_execution(
|
||||
execution_id,
|
||||
state="complete",
|
||||
stepResults=step_results,
|
||||
)
|
||||
logger.info("Pipeline %s completed successfully", pipeline_name)
|
||||
|
||||
# Audit: execution completed
|
||||
asyncio.create_task(self.state.log_event(
|
||||
execution_id, "execution_completed",
|
||||
message=f"{len(steps)} steps completed",
|
||||
))
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Pipeline %s failed unexpectedly: %s", pipeline_name, exc, exc_info=True)
|
||||
sentry_sdk.capture_exception(exc)
|
||||
await self.state.update_execution(
|
||||
execution_id,
|
||||
state="failed",
|
||||
stepResults=step_results,
|
||||
error=str(exc),
|
||||
)
|
||||
asyncio.create_task(self.state.log_event(
|
||||
execution_id, "execution_failed", message=str(exc)
|
||||
))
|
||||
|
||||
async def _execute_step_with_retry(
|
||||
self, *, step_type, step_name, rendered_step, context, target_room,
|
||||
execution_id, timeout_s, max_retries, step_results, pipeline_name,
|
||||
):
|
||||
"""Execute a step with retry on transient failures. Returns output or _STEP_FAILED."""
|
||||
last_exc = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
if step_type == "approval":
|
||||
output = await self._execute_approval_step(
|
||||
rendered_step, target_room, execution_id, timeout_s
|
||||
)
|
||||
else:
|
||||
output = await asyncio.wait_for(
|
||||
execute_step(
|
||||
step_type=step_type,
|
||||
step_config=rendered_step,
|
||||
context=context,
|
||||
send_text=self.send_text,
|
||||
target_room=target_room,
|
||||
llm=self.llm,
|
||||
default_model=self.default_model,
|
||||
escalation_model=self.escalation_model,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
return output
|
||||
|
||||
except TRANSIENT_EXCEPTIONS as exc:
|
||||
last_exc = exc
|
||||
if attempt < max_retries:
|
||||
backoff = 2 ** attempt # 1s, 2s, 4s
|
||||
logger.warning(
|
||||
"Pipeline %s: step %s transient failure (attempt %d/%d), retrying in %ds: %s",
|
||||
pipeline_name, step_name, attempt + 1, max_retries + 1, backoff, exc,
|
||||
)
|
||||
asyncio.create_task(self.state.log_event(
|
||||
execution_id, "retry_attempted", step_name=step_name,
|
||||
message=f"Attempt {attempt + 1}/{max_retries + 1}: {exc}",
|
||||
))
|
||||
await asyncio.sleep(backoff)
|
||||
continue
|
||||
# Exhausted retries — fall through to failure
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
error_msg = f"Step {step_name} timed out after {timeout_s}s"
|
||||
logger.error("Pipeline %s: %s", pipeline_name, error_msg)
|
||||
sentry_sdk.capture_message(error_msg, level="error")
|
||||
step_results.append({
|
||||
"name": step_name, "output": None, "status": "timeout",
|
||||
"error": error_msg, "timestamp": time.time(),
|
||||
})
|
||||
await self.state.update_execution(
|
||||
execution_id, state="failed", stepResults=step_results, error=error_msg,
|
||||
)
|
||||
await self.send_text(target_room, f"**{pipeline_name}**: {error_msg}")
|
||||
asyncio.create_task(self.state.log_event(
|
||||
execution_id, "step_failed", step_name=step_name, status="timeout", message=error_msg,
|
||||
))
|
||||
return _STEP_FAILED
|
||||
|
||||
except Exception as exc:
|
||||
error_msg = f"Step {step_name} failed: {exc}"
|
||||
logger.error("Pipeline %s: %s", pipeline_name, error_msg, exc_info=True)
|
||||
sentry_sdk.capture_exception(exc)
|
||||
step_results.append({
|
||||
"name": step_name, "output": None, "status": "error",
|
||||
"error": str(exc), "timestamp": time.time(),
|
||||
})
|
||||
await self.state.update_execution(
|
||||
execution_id, state="failed", stepResults=step_results, error=error_msg,
|
||||
)
|
||||
await self.send_text(target_room, f"**{pipeline_name}**: {error_msg}")
|
||||
asyncio.create_task(self.state.log_event(
|
||||
execution_id, "step_failed", step_name=step_name, status="error", message=str(exc),
|
||||
))
|
||||
return _STEP_FAILED
|
||||
|
||||
# Retries exhausted on transient failure
|
||||
error_msg = f"Step {step_name} failed after {max_retries + 1} attempts: {last_exc}"
|
||||
logger.error("Pipeline %s: %s", pipeline_name, error_msg)
|
||||
sentry_sdk.capture_exception(last_exc)
|
||||
step_results.append({
|
||||
"name": step_name, "output": None, "status": "error",
|
||||
"error": error_msg, "timestamp": time.time(), "retries": max_retries,
|
||||
})
|
||||
await self.state.update_execution(
|
||||
execution_id, state="failed", stepResults=step_results, error=error_msg,
|
||||
)
|
||||
await self.send_text(target_room, f"**{pipeline_name}**: {error_msg}")
|
||||
asyncio.create_task(self.state.log_event(
|
||||
execution_id, "step_failed", step_name=step_name, status="error", message=error_msg,
|
||||
))
|
||||
return _STEP_FAILED
|
||||
|
||||
@staticmethod
|
||||
def _md_to_html(text: str) -> str:
|
||||
"""Convert basic markdown to HTML for Matrix formatted_body."""
|
||||
import re as _re
|
||||
html = text
|
||||
# Bold: **text** -> <strong>text</strong>
|
||||
html = _re.sub(r'\*\*(.+?)\*\*', r'<strong>\1</strong>', html)
|
||||
# Italic: *text* -> <em>text</em>
|
||||
html = _re.sub(r'(?<!\*)\*(?!\*)(.+?)(?<!\*)\*(?!\*)', r'<em>\1</em>', html)
|
||||
# Code: `text` -> <code>text</code>
|
||||
html = _re.sub(r'`(.+?)`', r'<code>\1</code>', html)
|
||||
# Newlines
|
||||
html = html.replace("\n", "<br>")
|
||||
return html
|
||||
|
||||
async def _execute_approval_step(
|
||||
self, step: dict, target_room: str, execution_id: str, timeout_s: int
|
||||
) -> str:
|
||||
"""Post approval message and wait for reaction."""
|
||||
message = step.get("message", "Approve this action?")
|
||||
body = f"**Approval Required**\n\n{message}\n\nReact with \U0001f44d to approve or \U0001f44e to decline."
|
||||
html = self._md_to_html(body)
|
||||
|
||||
# Send message and get event ID
|
||||
resp = await self.matrix_client.room_send(
|
||||
room_id=target_room,
|
||||
message_type="m.room.message",
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": body,
|
||||
"format": "org.matrix.custom.html",
|
||||
"formatted_body": html,
|
||||
},
|
||||
)
|
||||
event_id = resp.event_id if hasattr(resp, "event_id") else None
|
||||
|
||||
if not event_id:
|
||||
raise RuntimeError("Failed to send approval message")
|
||||
|
||||
# Notify bot of approval event mapping
|
||||
if self.on_approval_registered:
|
||||
self.on_approval_registered(event_id, execution_id)
|
||||
|
||||
# Register approval listener
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._approval_futures[execution_id] = future
|
||||
|
||||
# Update execution with approval tracking
|
||||
expires_at = (datetime.now(timezone.utc) + timedelta(seconds=timeout_s)).isoformat()
|
||||
await self.state.update_execution(
|
||||
execution_id,
|
||||
state="waiting_approval",
|
||||
approvalMsgId=event_id,
|
||||
approvalExpiresAt=expires_at,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(future, timeout=timeout_s)
|
||||
asyncio.create_task(self.state.log_event(
|
||||
execution_id, "approval_resolved", message=f"Response: {result}",
|
||||
))
|
||||
return result # "approve" or "decline"
|
||||
except asyncio.TimeoutError:
|
||||
sentry_sdk.capture_message(
|
||||
f"Approval timeout: execution={execution_id}",
|
||||
level="warning",
|
||||
)
|
||||
await self.send_text(target_room, "Approval timed out. Pipeline aborted.")
|
||||
await self.state.update_execution(execution_id, state="aborted", error="Approval timed out")
|
||||
asyncio.create_task(self.state.log_event(
|
||||
execution_id, "step_failed", step_name="approval", status="timeout",
|
||||
message="Approval timed out",
|
||||
))
|
||||
raise
|
||||
finally:
|
||||
self._approval_futures.pop(execution_id, None)
|
||||
|
||||
def resolve_approval(self, execution_id: str, response: str) -> bool:
|
||||
"""Resolve a pending approval. Called by reaction handler."""
|
||||
future = self._approval_futures.get(execution_id)
|
||||
if future and not future.done():
|
||||
future.set_result(response)
|
||||
return True
|
||||
return False
|
||||
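render_template and evaluate_condition are deliberately tiny: dotted lookups into the context dict plus string equality, not real Jinja2. Note that the comparison has to sit outside the braces; a dotted path with an operator inside one {{ ... }} block fails the lookup and is left verbatim. An illustration with a hypothetical context:

from pipelines.engine import PipelineEngine

engine = PipelineEngine(state=None, send_text=None, matrix_client=None)  # deps unused by these helpers
ctx = {
    "scan": {"output": "3 new listings", "status": "success"},
    "gate": {"response": "approve"},
}

assert engine.render_template("Found: {{ scan.output }}", ctx) == "Found: 3 new listings"
assert engine.evaluate_condition("{{ gate.response }} == 'approve'", ctx) is True
# Unknown names are left verbatim rather than raising:
assert engine.render_template("{{ missing.output }}", ctx) == "{{ missing.output }}"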
pipelines/state.py (new file, 102 lines)
@@ -0,0 +1,102 @@
"""Pipeline state management — syncs with matrixhost portal API."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineStateManager:
|
||||
"""Manages pipeline state via portal API."""
|
||||
|
||||
def __init__(self, portal_url: str, api_key: str):
|
||||
self.portal_url = portal_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
|
||||
async def fetch_active_pipelines(self) -> list[dict]:
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.get(
|
||||
f"{self.portal_url}/api/pipelines/active",
|
||||
headers={"x-api-key": self.api_key},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data.get("pipelines", [])
|
||||
|
||||
async def create_execution(self, pipeline_id: str, trigger_data: dict | None = None) -> dict:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{self.portal_url}/api/pipelines/{pipeline_id}/execution",
|
||||
headers={"x-api-key": self.api_key},
|
||||
json={"triggerData": trigger_data},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data["execution"]
|
||||
|
||||
async def update_execution(self, execution_id: str, **kwargs) -> None:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
await client.put(
|
||||
f"{self.portal_url}/api/pipelines/executions/{execution_id}",
|
||||
headers={"x-api-key": self.api_key},
|
||||
json=kwargs,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to update execution %s", execution_id, exc_info=True)
|
||||
|
||||
async def fetch_pending_approvals(self) -> list[dict]:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
f"{self.portal_url}/api/pipelines/executions/pending",
|
||||
headers={"x-api-key": self.api_key},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data.get("executions", [])
|
||||
except Exception:
|
||||
logger.debug("Failed to fetch pending approvals", exc_info=True)
|
||||
return []
|
||||
|
||||
async def count_active_executions(self, user_id: str) -> int:
|
||||
"""Count running/waiting_approval executions for a user."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
f"{self.portal_url}/api/pipelines/executions/active",
|
||||
headers={"x-api-key": self.api_key},
|
||||
params={"userId": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data.get("count", 0)
|
||||
except Exception:
|
||||
logger.warning("Failed to count active executions for user %s", user_id, exc_info=True)
|
||||
return 0
|
||||
|
||||
async def log_event(
|
||||
self, execution_id: str, action: str, *,
|
||||
step_name: str | None = None,
|
||||
status: str | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""Log an audit event for a pipeline execution (fire-and-forget)."""
|
||||
try:
|
||||
payload = {"action": action}
|
||||
if step_name:
|
||||
payload["stepName"] = step_name
|
||||
if status:
|
||||
payload["status"] = status
|
||||
if message:
|
||||
payload["message"] = message
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
await client.post(
|
||||
f"{self.portal_url}/api/pipelines/executions/{execution_id}/audit-log",
|
||||
headers={"x-api-key": self.api_key},
|
||||
json=payload,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to log audit event %s for %s", action, execution_id)
|
||||
pipelines/steps/__init__.py (new file, 47 lines)
@@ -0,0 +1,47 @@
"""Step type registry and dispatcher."""
|
||||
|
||||
import logging
|
||||
|
||||
from .script import execute_script
|
||||
from .claude_prompt import execute_claude_prompt
|
||||
from .template import execute_template
|
||||
from .api_call import execute_api_call
|
||||
from .skyvern import execute_skyvern
|
||||
from .pitrader_step import execute_pitrader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STEP_EXECUTORS = {
|
||||
"script": execute_script,
|
||||
"claude_prompt": execute_claude_prompt,
|
||||
"template": execute_template,
|
||||
"api_call": execute_api_call,
|
||||
"skyvern": execute_skyvern,
|
||||
"pitrader_script": execute_pitrader,
|
||||
}
|
||||
|
||||
|
||||
async def execute_step(
|
||||
step_type: str,
|
||||
step_config: dict,
|
||||
context: dict,
|
||||
send_text,
|
||||
target_room: str,
|
||||
llm=None,
|
||||
default_model: str = "claude-haiku",
|
||||
escalation_model: str = "claude-sonnet",
|
||||
) -> str:
|
||||
"""Execute a pipeline step and return its output as a string."""
|
||||
executor = STEP_EXECUTORS.get(step_type)
|
||||
if not executor:
|
||||
raise ValueError(f"Unknown step type: {step_type}")
|
||||
|
||||
return await executor(
|
||||
config=step_config,
|
||||
context=context,
|
||||
send_text=send_text,
|
||||
target_room=target_room,
|
||||
llm=llm,
|
||||
default_model=default_model,
|
||||
escalation_model=escalation_model,
|
||||
)
|
||||
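Every executor accepts the same keyword set and discards what it does not need via **_kwargs, so adding a step type is one new module plus one registry entry. A hypothetical pipeline definition wired through this dispatcher (field names match what engine.run reads; the values are invented):

pipeline = {
    "id": "pl-demo", "name": "morning-brief", "targetRoom": "!room:example.org",
    "userId": "@alice:example.org", "maxRetries": 2, "triggerType": "cron",
    "steps": [
        {"name": "fetch", "type": "api_call", "url": "https://example.org/feed"},
        {"name": "summarise", "type": "claude_prompt",
         "prompt": "Summarise in 3 bullets:\n{{ fetch.output }}"},
        {"name": "post", "type": "template", "template": "{{ summarise.output }}"},
    ],
}
# await engine.run(pipeline)  # engine: PipelineEngine from pipelines.engine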
pipelines/steps/api_call.py (new file, 33 lines)
@@ -0,0 +1,33 @@
"""API call step — make HTTP requests."""
|
||||
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def execute_api_call(config: dict, **_kwargs) -> str:
|
||||
"""Make an HTTP request and return the response body."""
|
||||
url = config.get("url", "")
|
||||
if not url:
|
||||
raise ValueError("api_call step requires 'url' field")
|
||||
|
||||
method = config.get("method", "GET").upper()
|
||||
headers = config.get("headers", {})
|
||||
body = config.get("body")
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
if method == "GET":
|
||||
resp = await client.get(url, headers=headers)
|
||||
elif method == "POST":
|
||||
resp = await client.post(url, headers=headers, content=body)
|
||||
elif method == "PUT":
|
||||
resp = await client.put(url, headers=headers, content=body)
|
||||
elif method == "DELETE":
|
||||
resp = await client.delete(url, headers=headers)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
resp.raise_for_status()
|
||||
return resp.text
|
||||
pipelines/steps/claude_prompt.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""Claude prompt step — call LLM via LiteLLM proxy."""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def execute_claude_prompt(
|
||||
config: dict,
|
||||
context: dict | None = None,
|
||||
llm=None,
|
||||
default_model: str = "claude-haiku",
|
||||
escalation_model: str = "claude-sonnet",
|
||||
**_kwargs,
|
||||
) -> str:
|
||||
"""Send a prompt to Claude and return the response.
|
||||
|
||||
Supports vision: if config contains 'image_b64' or trigger context has
|
||||
'image_b64', the image is included as a vision content block.
|
||||
"""
|
||||
if not llm:
|
||||
raise RuntimeError("LLM client not configured")
|
||||
|
||||
prompt = config.get("prompt", "")
|
||||
if not prompt:
|
||||
raise ValueError("claude_prompt step requires 'prompt' field")
|
||||
|
||||
model_name = config.get("model", "default")
|
||||
model = escalation_model if model_name == "escalation" else default_model
|
||||
|
||||
# Check for image data (from config or trigger context)
|
||||
image_b64 = config.get("image_b64", "")
|
||||
image_mime = config.get("image_mime", "image/png")
|
||||
if not image_b64 and context:
|
||||
trigger = context.get("trigger", {})
|
||||
image_b64 = trigger.get("image_b64", "")
|
||||
image_mime = trigger.get("mime_type", "image/png")
|
||||
|
||||
# Build message content
|
||||
if image_b64:
|
||||
content = [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{image_mime};base64,{image_b64}",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
},
|
||||
]
|
||||
else:
|
||||
content = prompt
|
||||
|
||||
response = await llm.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": content}],
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content or ""
|
||||
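A step opts into the larger model by setting "model": "escalation" rather than naming a deployment, and an image-triggered pipeline gets vision for free because the trigger payload lands in context["trigger"]. A hypothetical step config exercising both (values invented):

step = {
    "name": "describe", "type": "claude_prompt",
    "model": "escalation",  # routes to escalation_model ("claude-sonnet" by default)
    "prompt": "Describe this screenshot for the changelog.",
}
# With a media trigger, engine.run puts {"image_b64": ..., "mime_type": "image/png"}
# under context["trigger"], and execute_claude_prompt builds a vision content block.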
pipelines/steps/pitrader_step.py (new file, 110 lines)
@@ -0,0 +1,110 @@
"""PITrader step — execute PITrader scripts with JSON output capture."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PITRADER_DIR = os.environ.get("PITRADER_DIR", os.path.expanduser("~/Development/Apps/PITrader"))
|
||||
|
||||
|
||||
async def execute_pitrader(config: dict, **_kwargs) -> str:
|
||||
"""Execute a PITrader script and return JSON output.
|
||||
|
||||
Config fields:
|
||||
script: Script path relative to PITrader dir (e.g., "scripts/pi-scan")
|
||||
or absolute path
|
||||
args: List of CLI arguments (default: [])
|
||||
timeout_s: Override timeout in seconds (default: 300)
|
||||
json_output: If true, append --json flag (default: true)
|
||||
env: Extra environment variables dict (default: {})
|
||||
"""
|
||||
script = config.get("script", "")
|
||||
if not script:
|
||||
raise ValueError("pitrader_script step requires 'script' field")
|
||||
|
||||
args = config.get("args", [])
|
||||
if isinstance(args, str):
|
||||
args = args.split()
|
||||
json_output = config.get("json_output", True)
|
||||
extra_env = config.get("env", {})
|
||||
timeout_s = config.get("timeout_s", 300)
|
||||
|
||||
# Build command
|
||||
if not os.path.isabs(script):
|
||||
script = os.path.join(PITRADER_DIR, script)
|
||||
|
||||
cmd_parts = [script] + args
|
||||
if json_output and "--json" not in args:
|
||||
cmd_parts.append("--json")
|
||||
|
||||
cmd = " ".join(cmd_parts)
|
||||
logger.info("PITrader step: %s (cwd=%s, timeout=%ds)", cmd, PITRADER_DIR, timeout_s)
|
||||
|
||||
# Build environment: inherit current + PYTHONPATH for imports + extras
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = PITRADER_DIR + os.pathsep + env.get("PYTHONPATH", "")
|
||||
env.update(extra_env)
|
||||
|
||||
# Fetch vault credentials if not already in env
|
||||
if "ETORO_API_KEY" not in env:
|
||||
try:
|
||||
vault_proc = await asyncio.create_subprocess_exec(
|
||||
"vault", "get", "etoro.api_key",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, _ = await asyncio.wait_for(vault_proc.communicate(), timeout=10)
|
||||
if vault_proc.returncode == 0:
|
||||
env["ETORO_API_KEY"] = stdout.decode().strip()
|
||||
except Exception:
|
||||
logger.debug("Could not fetch etoro.api_key from vault")
|
||||
|
||||
if "ETORO_USER_KEY" not in env:
|
||||
try:
|
||||
vault_proc = await asyncio.create_subprocess_exec(
|
||||
"vault", "get", "etoro.user_key",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, _ = await asyncio.wait_for(vault_proc.communicate(), timeout=10)
|
||||
if vault_proc.returncode == 0:
|
||||
env["ETORO_USER_KEY"] = stdout.decode().strip()
|
||||
except Exception:
|
||||
logger.debug("Could not fetch etoro.user_key from vault")
|
||||
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=PITRADER_DIR,
|
||||
env=env,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout_s)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
raise RuntimeError(f"PITrader script timed out after {timeout_s}s: {cmd}")
|
||||
|
||||
output = stdout.decode("utf-8", errors="replace").strip()
|
||||
err_output = stderr.decode("utf-8", errors="replace").strip()
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"PITrader script exited with code {proc.returncode}: {err_output or output}"
|
||||
)
|
||||
|
||||
# Validate JSON output if expected
|
||||
if json_output and output:
|
||||
try:
|
||||
json.loads(output)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("PITrader output is not valid JSON, returning raw output")
|
||||
|
||||
if err_output:
|
||||
logger.debug("PITrader stderr: %s", err_output[:500])
|
||||
|
||||
return output
|
||||
pipelines/steps/script.py (new file, 27 lines)
@@ -0,0 +1,27 @@
"""Script step — execute a shell command and capture output."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def execute_script(config: dict, **_kwargs) -> str:
|
||||
"""Execute a shell script and return stdout."""
|
||||
script = config.get("script", "")
|
||||
if not script:
|
||||
raise ValueError("Script step requires 'script' field")
|
||||
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
script,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
output = stdout.decode("utf-8", errors="replace").strip()
|
||||
if proc.returncode != 0:
|
||||
err = stderr.decode("utf-8", errors="replace").strip()
|
||||
raise RuntimeError(f"Script exited with code {proc.returncode}: {err or output}")
|
||||
|
||||
return output
|
||||
pipelines/steps/skyvern.py (new file, 105 lines)
@@ -0,0 +1,105 @@
"""Skyvern step — browser automation via Skyvern API for pipeline execution."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SKYVERN_BASE_URL = os.environ.get("SKYVERN_BASE_URL", "http://skyvern:8000")
|
||||
SKYVERN_API_KEY = os.environ.get("SKYVERN_API_KEY", "")
|
||||
|
||||
POLL_INTERVAL = 5
|
||||
MAX_POLL_TIME = 300
|
||||
|
||||
|
||||
async def execute_skyvern(config: dict, send_text=None, target_room: str = "", **_kwargs) -> str:
|
||||
"""Dispatch a browser task to Skyvern and return extracted data.
|
||||
|
||||
Config fields:
|
||||
url: target URL (required)
|
||||
goal: navigation goal / prompt (required)
|
||||
data_extraction_goal: what to extract (optional, added to prompt)
|
||||
extraction_schema: JSON schema for structured extraction (optional)
|
||||
credential_id: Skyvern credential ID for login (optional)
|
||||
totp_identifier: email/phone for TOTP (optional)
|
||||
timeout_s: max poll time in seconds (optional, default 300)
|
||||
"""
|
||||
if not SKYVERN_API_KEY:
|
||||
raise RuntimeError("SKYVERN_API_KEY not configured")
|
||||
|
||||
url = config.get("url", "")
|
||||
goal = config.get("goal", "")
|
||||
data_extraction_goal = config.get("data_extraction_goal", "")
|
||||
extraction_schema = config.get("extraction_schema")
|
||||
credential_id = config.get("credential_id")
|
||||
totp_identifier = config.get("totp_identifier")
|
||||
max_poll = config.get("timeout_s", MAX_POLL_TIME)
|
||||
|
||||
if not url or not goal:
|
||||
raise ValueError("Skyvern step requires 'url' and 'goal' in config")
|
||||
|
||||
payload: dict = {
|
||||
"url": url,
|
||||
"navigation_goal": goal,
|
||||
"data_extraction_goal": data_extraction_goal or goal,
|
||||
}
|
||||
if extraction_schema:
|
||||
if isinstance(extraction_schema, str):
|
||||
extraction_schema = json.loads(extraction_schema)
|
||||
payload["extracted_information_schema"] = extraction_schema
|
||||
if credential_id:
|
||||
payload["credential_id"] = credential_id
|
||||
if totp_identifier:
|
||||
payload["totp_identifier"] = totp_identifier
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": SKYVERN_API_KEY,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.post(
|
||||
f"{SKYVERN_BASE_URL}/api/v1/tasks",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
run_id = resp.json()["task_id"]
|
||||
|
||||
logger.info("Skyvern pipeline task created: %s", run_id)
|
||||
|
||||
if send_text and target_room:
|
||||
await send_text(target_room, f"Browser task started for {url}...")
|
||||
|
||||
# Poll for completion
|
||||
elapsed = 0
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
while elapsed < max_poll:
|
||||
resp = await client.get(
|
||||
f"{SKYVERN_BASE_URL}/api/v1/tasks/{run_id}",
|
||||
headers={"x-api-key": SKYVERN_API_KEY},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
status = data.get("status", "")
|
||||
|
||||
if status == "completed":
|
||||
extracted = data.get("extracted_information") or data.get("extracted_data")
|
||||
if extracted is None:
|
||||
return "Task completed, no data extracted."
|
||||
if isinstance(extracted, (dict, list)):
|
||||
return json.dumps(extracted, ensure_ascii=False)
|
||||
return str(extracted)
|
||||
|
||||
if status in ("failed", "terminated", "timed_out"):
|
||||
error = data.get("error") or data.get("failure_reason") or status
|
||||
raise RuntimeError(f"Skyvern task {status}: {error}")
|
||||
|
||||
await asyncio.sleep(POLL_INTERVAL)
|
||||
elapsed += POLL_INTERVAL
|
||||
|
||||
raise TimeoutError(f"Skyvern task {run_id} did not complete within {max_poll}s")
|
||||
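A hypothetical skyvern step as it would appear in a pipeline definition, using only the config fields documented in the docstring above (URL and schema are invented):

step = {
    "name": "check_prices", "type": "skyvern",
    "url": "https://shop.example.org/deals",
    "goal": "Find today's discounted items",
    "data_extraction_goal": "Extract item names and prices",
    "extraction_schema": {"type": "object",
                          "properties": {"items": {"type": "array"}}},
    "timeout_s": 180,
}
# execute_skyvern polls /api/v1/tasks/{id} every 5s and returns the extracted
# information JSON-encoded, so downstream steps can template it.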
pipelines/steps/template.py (new file, 18 lines)
@@ -0,0 +1,18 @@
"""Template step — format and post a message to the target room."""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def execute_template(config: dict, send_text=None, target_room: str = "", **_kwargs) -> str:
|
||||
"""Render a template message and post it to the target room."""
|
||||
template = config.get("template", config.get("message", ""))
|
||||
if not template:
|
||||
raise ValueError("template step requires 'template' or 'message' field")
|
||||
|
||||
# Template is already rendered by the engine before reaching here
|
||||
if send_text and target_room:
|
||||
await send_text(target_room, template)
|
||||
|
||||
return template
|
||||
rag_key_manager.py (new file, 186 lines)
@@ -0,0 +1,186 @@
"""
|
||||
RAG Document Key Manager — stores per-user encryption keys in Matrix E2EE rooms.
|
||||
|
||||
The key is stored as an encrypted event in a private room that only the bot can access.
|
||||
On startup, the bot syncs the room and re-injects the key into the RAG service
|
||||
via the portal proxy (since RAG service is localhost-only on the customer VM).
|
||||
No plaintext keys are ever written to disk.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import logging
|
||||
import httpx
|
||||
from nio.api import RoomVisibility
|
||||
|
||||
logger = logging.getLogger("matrix-ai-bot")
|
||||
|
||||
KEY_EVENT_TYPE = "eu.matrixhost.rag_document_key"
|
||||
KEY_ROOM_TOPIC = "RAG Document Encryption Keys \u2014 DO NOT LEAVE"
|
||||
|
||||
|
||||
class RAGKeyManager:
|
||||
"""Manages per-user document encryption keys via Matrix E2EE."""
|
||||
|
||||
def __init__(self, client, portal_url: str, bot_api_key: str):
|
||||
self.client = client
|
||||
self.portal_url = portal_url.rstrip("/") if portal_url else ""
|
||||
self.bot_api_key = bot_api_key
|
||||
self._key_room_id: str | None = None
|
||||
|
||||
async def ensure_rag_key(self, seed_key_hex: str | None = None) -> bool:
|
||||
"""Ensure RAG service has encryption key loaded.
|
||||
|
||||
Args:
|
||||
seed_key_hex: Existing key to migrate into Matrix storage (one-time).
|
||||
"""
|
||||
if not self.portal_url:
|
||||
logger.warning("[rag-key] No portal URL configured")
|
||||
return False
|
||||
|
||||
# Check if RAG already has a key
|
||||
if await self._rag_has_key():
|
||||
logger.info("[rag-key] RAG service already has key loaded")
|
||||
room_id = await self._find_or_create_key_room()
|
||||
if room_id:
|
||||
existing = await self._load_key_from_room(room_id)
|
||||
if not existing and seed_key_hex:
|
||||
await self._store_key_in_room(room_id, seed_key_hex)
|
||||
logger.info("[rag-key] Migrated existing key into Matrix E2EE room")
|
||||
return True
|
||||
|
||||
# Find or create the key storage room
|
||||
room_id = await self._find_or_create_key_room()
|
||||
if not room_id:
|
||||
logger.error("[rag-key] Failed to find or create key room")
|
||||
return False
|
||||
|
||||
# Try to load existing key from room
|
||||
key_hex = await self._load_key_from_room(room_id)
|
||||
|
||||
if key_hex:
|
||||
logger.info("[rag-key] Loaded existing key from Matrix room")
|
||||
elif seed_key_hex:
|
||||
key_hex = seed_key_hex
|
||||
stored = await self._store_key_in_room(room_id, key_hex)
|
||||
if not stored:
|
||||
logger.error("[rag-key] Failed to store seed key in Matrix room")
|
||||
return False
|
||||
logger.info("[rag-key] Stored migration seed key in Matrix E2EE room")
|
||||
else:
|
||||
key_hex = secrets.token_hex(32)
|
||||
stored = await self._store_key_in_room(room_id, key_hex)
|
||||
if not stored:
|
||||
logger.error("[rag-key] Failed to store new key in Matrix room")
|
||||
return False
|
||||
logger.info("[rag-key] Generated and stored new encryption key")
|
||||
|
||||
return await self._inject_key(key_hex)
|
||||
|
||||
async def _rag_has_key(self) -> bool:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
resp = await client.get(
|
||||
f"{self.portal_url}/api/bot/rag-key",
|
||||
headers={"Authorization": f"Bearer {self.bot_api_key}"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("has_key", False)
|
||||
except Exception as e:
|
||||
logger.debug("[rag-key] Health check failed: %s", e)
|
||||
return False
|
||||
|
||||
async def _find_or_create_key_room(self) -> str | None:
|
||||
for room_id, room in self.client.rooms.items():
|
||||
if room.topic == KEY_ROOM_TOPIC:
|
||||
self._key_room_id = room_id
|
||||
logger.info("[rag-key] Found existing key room: %s", room_id)
|
||||
return room_id
|
||||
|
||||
try:
|
||||
from nio import EnableEncryptionBuilder
|
||||
initial_state = [EnableEncryptionBuilder().as_dict()]
|
||||
except ImportError:
|
||||
initial_state = [{
|
||||
"type": "m.room.encryption",
|
||||
"state_key": "",
|
||||
"content": {"algorithm": "m.megolm.v1.aes-sha2"},
|
||||
}]
|
||||
|
||||
resp = await self.client.room_create(
|
||||
name="RAG Key Storage",
|
||||
topic=KEY_ROOM_TOPIC,
|
||||
invite=[],
|
||||
initial_state=initial_state,
|
||||
visibility=RoomVisibility.private,
|
||||
)
|
||||
|
||||
if hasattr(resp, "room_id"):
|
||||
self._key_room_id = resp.room_id
|
||||
logger.info("[rag-key] Created new key room: %s", resp.room_id)
|
||||
return resp.room_id
|
||||
|
||||
logger.error("[rag-key] Failed to create key room: %s", resp)
|
||||
return None
|
||||
|
||||
async def _load_key_from_room(self, room_id: str) -> str | None:
|
||||
try:
|
||||
resp = await self.client.room_messages(
|
||||
room_id, start="", limit=50, direction="b",
|
||||
)
|
||||
if not hasattr(resp, "chunk"):
|
||||
return None
|
||||
|
||||
for event in resp.chunk:
|
||||
if hasattr(event, "source"):
|
||||
source = event.source
|
||||
if source.get("type") == KEY_EVENT_TYPE:
|
||||
key = source.get("content", {}).get("key_hex")
|
||||
if key:
|
||||
return key
|
||||
|
||||
if hasattr(event, "type") and event.type == KEY_EVENT_TYPE:
|
||||
if hasattr(event, "content"):
|
||||
key = event.content.get("key_hex")
|
||||
if key:
|
||||
return key
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("[rag-key] Failed to load key from room: %s", e)
|
||||
return None
|
||||
|
||||
async def _store_key_in_room(self, room_id: str, key_hex: str) -> bool:
|
||||
try:
|
||||
content = {
|
||||
"key_hex": key_hex,
|
||||
"algorithm": "aes-256-gcm",
|
||||
"purpose": "rag-document-encryption",
|
||||
"msgtype": "eu.matrixhost.rag_key",
|
||||
}
|
||||
resp = await self.client.room_send(
|
||||
room_id, message_type=KEY_EVENT_TYPE,
|
||||
content=content, ignore_unverified_devices=True,
|
||||
)
|
||||
if hasattr(resp, "event_id"):
|
||||
logger.info("[rag-key] Key stored as event %s", resp.event_id)
|
||||
return True
|
||||
logger.error("[rag-key] Failed to send key event: %s", resp)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("[rag-key] Failed to store key: %s", e)
|
||||
return False
|
||||
|
||||
async def _inject_key(self, key_hex: str) -> bool:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{self.portal_url}/api/bot/rag-key",
|
||||
json={"key_hex": key_hex},
|
||||
headers={"Authorization": f"Bearer {self.bot_api_key}"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info("[rag-key] Key injected into RAG service via portal proxy")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("[rag-key] Failed to inject key: %s", e)
|
||||
return False
|
||||
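A rough sketch of how this might be wired in at bot startup, assuming a logged-in, synced `nio.AsyncClient` and env-based configuration (the variable names and URLs below are illustrative, not taken from the bot's real entry point):

```python
import asyncio
import os

from nio import AsyncClient

from rag_key_manager import RAGKeyManager

async def startup():
    client = AsyncClient("https://matrix.example.org", "@ai:example.org")
    # ... login and initial sync elided ...
    manager = RAGKeyManager(
        client,
        portal_url=os.environ.get("PORTAL_URL", ""),
        bot_api_key=os.environ.get("BOT_API_KEY", ""),
    )
    # Optional one-time migration: pass the old on-disk key as seed_key_hex.
    ok = await manager.ensure_rag_key(seed_key_hex=os.environ.get("RAG_SEED_KEY"))
    if not ok:
        raise RuntimeError("RAG key bootstrap failed")

asyncio.run(startup())
```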
3 requirements-test.txt Normal file
@@ -0,0 +1,3 @@
pytest>=7.4,<9.0
pytest-asyncio>=0.21,<1.0
pytest-cov>=4.1,<6.0
@@ -9,3 +9,8 @@ canonicaljson>=2.0,<3.0
httpx>=0.27,<1.0
openai>=2.0,<3.0
pymupdf>=1.24,<2.0
python-docx>=1.0,<2.0
Pillow>=10.0,<12.0
beautifulsoup4>=4.12
lxml>=5.0
sentry-sdk>=2.0,<3.0
230 scripts/matrix_device_cleanup.py Executable file
@@ -0,0 +1,230 @@
#!/usr/bin/env python3
"""Clean up stale Matrix devices via the Synapse Admin API.

Usage:
    python matrix_device_cleanup.py --user @admin:agiliton.eu --keep 1 --dry-run
    python matrix_device_cleanup.py --user @admin:agiliton.eu --keep 1
    python matrix_device_cleanup.py --auto --max-age-days 30 --keep 3
"""
import argparse
import asyncio
import json
import logging
import os
import subprocess
import sys
import time
from urllib.parse import quote

import httpx

logger = logging.getLogger(__name__)

BATCH_SIZE = 100
BATCH_DELAY = 1.0  # seconds between batch deletions


async def get_admin_token(homeserver: str) -> str:
    """Get the Synapse admin token from the environment or the vault."""
    token = os.environ.get("SYNAPSE_ADMIN_TOKEN")
    if token:
        return token

    try:
        result = subprocess.run(
            ["vault", "get", "matrix.agiliton.admin_token"],
            capture_output=True, text=True, timeout=10,
        )
        if result.returncode == 0 and result.stdout.strip():
            return result.stdout.strip()
    except (FileNotFoundError, subprocess.TimeoutExpired):
        pass

    raise RuntimeError(
        "No admin token found. Set SYNAPSE_ADMIN_TOKEN or store in vault "
        "as matrix.agiliton.admin_token"
    )


async def list_devices(
    client: httpx.AsyncClient, homeserver: str, headers: dict, user_id: str,
) -> list[dict]:
    """List all devices for a user via the Synapse Admin API."""
    encoded_user = quote(user_id, safe="")
    resp = await client.get(
        f"{homeserver}/_synapse/admin/v2/users/{encoded_user}/devices",
        headers=headers,
    )
    resp.raise_for_status()
    return resp.json().get("devices", [])


async def delete_devices_batch(
    client: httpx.AsyncClient,
    homeserver: str,
    headers: dict,
    user_id: str,
    device_ids: list[str],
) -> int:
    """Bulk-delete devices. Returns the count deleted."""
    encoded_user = quote(user_id, safe="")
    resp = await client.post(
        f"{homeserver}/_synapse/admin/v2/users/{encoded_user}/delete_devices",
        headers=headers,
        json={"devices": device_ids},
    )
    resp.raise_for_status()
    return len(device_ids)


async def cleanup_devices(
    homeserver: str,
    user_id: str,
    keep: int = 1,
    max_age_days: int | None = None,
    dry_run: bool = False,
    skip_device_ids: list[str] | None = None,
) -> dict:
    """Remove stale devices, keeping the N most recently active.

    Returns a summary dict with counts.
    """
    token = await get_admin_token(homeserver)
    headers = {"Authorization": f"Bearer {token}"}
    skip = set(skip_device_ids or [])

    async with httpx.AsyncClient(timeout=30.0) as client:
        devices = await list_devices(client, homeserver, headers, user_id)

        if not devices:
            logger.info("No devices found for %s", user_id)
            return {"total": 0, "kept": 0, "deleted": 0}

        # Sort by last_seen_ts descending (most recent first), treat None as 0
        devices.sort(key=lambda d: d.get("last_seen_ts") or 0, reverse=True)

        # Determine which to keep
        to_keep = []
        to_delete = []

        for i, dev in enumerate(devices):
            dev_id = dev["device_id"]
            last_seen = dev.get("last_seen_ts") or 0

            # Always skip explicitly protected devices
            if dev_id in skip:
                to_keep.append(dev)
                continue

            # Keep the top N most recent
            if i < keep:
                to_keep.append(dev)
                continue

            # If max_age_days is set, only delete devices older than the threshold
            if max_age_days is not None and last_seen > 0:
                age_days = (time.time() * 1000 - last_seen) / (86400 * 1000)
                if age_days < max_age_days:
                    to_keep.append(dev)
                    continue

            to_delete.append(dev)

        logger.info(
            "User %s: %d total devices, keeping %d, deleting %d%s",
            user_id, len(devices), len(to_keep), len(to_delete),
            " (DRY RUN)" if dry_run else "",
        )

        if dry_run:
            for dev in to_delete[:10]:
                last = dev.get("last_seen_ts") or 0
                age = f"{(time.time() * 1000 - last) / (86400 * 1000):.1f}d" if last else "never"
                logger.info(
                    " Would delete: %s (display: %s, last seen: %s ago)",
                    dev["device_id"],
                    dev.get("display_name", ""),
                    age,
                )
            if len(to_delete) > 10:
                logger.info(" ... and %d more", len(to_delete) - 10)
            return {
                "total": len(devices),
                "kept": len(to_keep),
                "deleted": 0,
                "would_delete": len(to_delete),
            }

        # Delete in batches
        deleted = 0
        delete_ids = [d["device_id"] for d in to_delete]

        for i in range(0, len(delete_ids), BATCH_SIZE):
            batch = delete_ids[i : i + BATCH_SIZE]
            try:
                count = await delete_devices_batch(
                    client, homeserver, headers, user_id, batch,
                )
                deleted += count
                logger.info(
                    " Deleted batch %d-%d (%d devices)",
                    i, i + len(batch), count,
                )
            except httpx.HTTPStatusError as e:
                logger.error(
                    " Batch %d-%d failed: %d %s",
                    i, i + len(batch), e.response.status_code, e.response.text,
                )

            if i + BATCH_SIZE < len(delete_ids):
                await asyncio.sleep(BATCH_DELAY)

        logger.info("Cleanup complete: deleted %d of %d devices", deleted, len(devices))
        return {"total": len(devices), "kept": len(to_keep), "deleted": deleted}


def main():
    parser = argparse.ArgumentParser(description="Clean up stale Matrix devices")
    parser.add_argument("--user", required=True, help="Matrix user ID (e.g. @admin:agiliton.eu)")
    parser.add_argument(
        "--homeserver",
        default=os.environ.get("MATRIX_HOMESERVER", "https://matrix.agiliton.eu"),
        help="Homeserver URL (use http://CONTAINER_IP:8008 when running on the matrix VM)",
    )
    parser.add_argument("--keep", type=int, default=1, help="Number of most recent devices to keep")
    parser.add_argument("--max-age-days", type=int, default=None, help="Only delete devices older than N days")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be deleted without deleting")
    parser.add_argument("--skip", nargs="*", default=[], help="Device IDs to never delete")
    parser.add_argument("--auto", action="store_true", help="Auto mode: --max-age-days 30 --keep 3")
    parser.add_argument("-v", "--verbose", action="store_true")

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(levelname)s %(message)s",
    )

    if args.auto:
        if args.max_age_days is None:
            args.max_age_days = 30
        if args.keep == 1:
            args.keep = 3

    result = asyncio.run(
        cleanup_devices(
            homeserver=args.homeserver,
            user_id=args.user,
            keep=args.keep,
            max_age_days=args.max_age_days,
            dry_run=args.dry_run,
            skip_device_ids=args.skip,
        )
    )

    print(json.dumps(result, indent=2))
    # "deleted" is always >= 0, so the previous conditional exit code was a no-op
    sys.exit(0)


if __name__ == "__main__":
    main()
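The coroutine can also be driven from other tooling; a dry-run sketch, assuming `scripts/` is on the import path and a token is available (the token value below is a placeholder):

```python
import asyncio
import os

from matrix_device_cleanup import cleanup_devices

# get_admin_token() checks this env var first; "syt_..." is a placeholder
os.environ.setdefault("SYNAPSE_ADMIN_TOKEN", "syt_...")

summary = asyncio.run(
    cleanup_devices(
        homeserver="https://matrix.agiliton.eu",
        user_id="@admin:agiliton.eu",
        keep=3,
        max_age_days=30,
        dry_run=True,  # report only; nothing is deleted
    )
)
print(summary)  # e.g. {"total": 42, "kept": 3, "deleted": 0, "would_delete": 39}
```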
21 setup_ssl.sh Normal file
@@ -0,0 +1,21 @@
#!/bin/bash
# MAT-107: Generate self-signed SSL cert for memory-db and configure postgres
set -euo pipefail

SSL_DIR="/opt/matrix-ai-agent/memory-db-ssl"
mkdir -p "$SSL_DIR"

# Generate self-signed cert (valid 10 years)
openssl req -new -x509 -days 3650 -nodes \
    -subj "/CN=memory-db" \
    -keyout "$SSL_DIR/server.key" \
    -out "$SSL_DIR/server.crt" \
    2>/dev/null

# Postgres requires specific permissions
chmod 600 "$SSL_DIR/server.key"
chmod 644 "$SSL_DIR/server.crt"
# Postgres runs as uid 999 in the pgvector container
chown 999:999 "$SSL_DIR/server.key" "$SSL_DIR/server.crt"

echo "SSL certs generated in $SSL_DIR"
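A quick sanity check one might run afterwards, assuming the same paths; it asserts the modes and ownership that Postgres requires before it will load a server key:

```python
import os
import stat

SSL_DIR = "/opt/matrix-ai-agent/memory-db-ssl"

key = os.stat(f"{SSL_DIR}/server.key")
crt = os.stat(f"{SSL_DIR}/server.crt")

# Postgres rejects a server key that is group- or world-readable
assert stat.S_IMODE(key.st_mode) == 0o600, "server.key must be 0600"
assert stat.S_IMODE(crt.st_mode) == 0o644, "server.crt should be 0644"
# uid 999 is the postgres user inside the pgvector container
assert key.st_uid == 999 and crt.st_uid == 999, "certs must be owned by uid 999"
print("memory-db SSL files look correct")
```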
199 test_element_call.py Normal file
@@ -0,0 +1,199 @@
"""Playwright test: Element Call with matrix-ai-agent bot.

Usage:
    python3 test_element_call.py [--headless]

Logs in as testbot-playwright, creates a DM with the bot, starts Element Call,
uses fake microphone audio, and monitors the bot logs for VAD/speech events.
"""
import asyncio
import argparse
import subprocess
import time

from playwright.async_api import async_playwright

# Test config
ELEMENT_URL = "https://element.agiliton.eu"
TEST_USER = "@testbot-playwright:agiliton.eu"
TEST_USER_LOCAL = "testbot-playwright"
TEST_PASSWORD = "TestP@ssw0rd-1771760269"
BOT_USER = "@ai:agiliton.eu"
HOMESERVER = "https://matrix.agiliton.eu"


async def wait_for_bot_event(keyword: str, timeout: int = 60) -> bool:
    """Poll the bot container logs for a specific keyword."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        result = subprocess.run(
            ["ssh", "root@matrix.agiliton.internal",
             "cd /opt/matrix-ai-agent && docker compose logs bot --tail=50 2>&1"],
            capture_output=True, text=True, timeout=15
        )
        if keyword in result.stdout:
            return True
        await asyncio.sleep(2)
    return False


async def run_test(headless: bool = True):
    async with async_playwright() as p:
        # Launch with a fake audio device so VAD can trigger
        browser = await p.chromium.launch(
            headless=headless,
            args=[
                "--use-fake-ui-for-media-stream",
                "--use-fake-device-for-media-stream",
                "--allow-running-insecure-content",
                "--disable-web-security",
                "--no-sandbox",
            ]
        )
        # Grant media permissions automatically
        context = await browser.new_context(
            permissions=["microphone", "camera"],
        )
        page = await context.new_page()

        # Capture console logs
        page.on("console", lambda msg: print(f" [browser] {msg.type}: {msg.text}") if msg.type in ("error", "warn") else None)

        print(f"[1] Navigating to {ELEMENT_URL}...")
        await page.goto(ELEMENT_URL, wait_until="networkidle", timeout=30000)
        await page.screenshot(path="/tmp/element-01-loaded.png")

        # Handle the "Continue" button if shown (welcome screen)
        try:
            await page.click("text=Continue", timeout=3000)
        except Exception:
            pass

        print("[2] Logging in...")
        # Click the Sign In button if present
        try:
            await page.click("text=Sign in", timeout=5000)
        except Exception:
            pass

        # Wait for the username field
        await page.wait_for_selector("input[type='text'], input[id='mx_LoginForm_username']", timeout=15000)
        await page.screenshot(path="/tmp/element-02-login.png")

        # Fill username
        username_input = page.locator("input[type='text'], input[id='mx_LoginForm_username']").first
        await username_input.fill(TEST_USER_LOCAL)

        # Fill password
        password_input = page.locator("input[type='password']").first
        await password_input.fill(TEST_PASSWORD)

        # Submit
        await page.keyboard.press("Enter")
        await page.wait_for_timeout(5000)
        await page.screenshot(path="/tmp/element-03-after-login.png")

        # Handle "Use without" / skip verification prompts
        for skip_text in ["Use without", "Skip", "I'll verify later", "Continue"]:
            try:
                await page.click(f"text={skip_text}", timeout=2000)
                await page.wait_for_timeout(1000)
            except Exception:
                pass

        await page.screenshot(path="/tmp/element-04-home.png")

        print("[3] Creating DM with bot...")
        # Click the new-DM button
        try:
            # Try the compose / start DM button
            await page.click("[aria-label='Start chat'], [title='Start chat'], button:has-text('Start')", timeout=5000)
        except Exception:
            # Try the + button near People
            try:
                await page.click("[aria-label='Add room'], .mx_RoomList_headerButtons button", timeout=5000)
            except Exception:
                print(" Could not find DM button, trying navigation...")
                await page.goto(f"{ELEMENT_URL}/#/new", timeout=10000)

        await page.wait_for_timeout(2000)
        await page.screenshot(path="/tmp/element-05-dm-dialog.png")

        # Search for the bot user
        try:
            dm_input = page.locator("input[type='text']").first
            await dm_input.fill(BOT_USER)
            await page.wait_for_timeout(2000)
            # Click on the result
            await page.click(f"text={BOT_USER}", timeout=5000)
            await page.wait_for_timeout(1000)
            # Confirm DM
            await page.click("button:has-text('Go'), button:has-text('OK'), button:has-text('Direct Message')", timeout=5000)
        except Exception as e:
            print(f" DM creation error: {e}")

        await page.wait_for_timeout(3000)
        await page.screenshot(path="/tmp/element-06-room.png")

        print("[4] Looking for call button...")
        # Look for the video call button in the room header
        try:
            await page.click("[aria-label='Video call'], [title='Video call'], button.mx_LegacyCallButton", timeout=10000)
            print(" Clicked video call button")
        except Exception as e:
            print(f" Could not find call button: {e}")
            # Try text-based
            try:
                await page.click("text=Video call", timeout=5000)
            except Exception:
                pass

        await page.wait_for_timeout(5000)
        await page.screenshot(path="/tmp/element-07-call-started.png")

        print("[5] Waiting for bot to join (60s)...")
        # Monitor the bot logs for the connection
        bot_joined = await wait_for_bot_event("Connected", timeout=60)
        if bot_joined:
            print(" ✓ Bot joined the call!")
        else:
            print(" ✗ Bot did not join within 60s")

        print("[6] Fake microphone is active — waiting for VAD events (30s)...")
        await page.wait_for_timeout(10000)  # let the call run for 10s
        await page.screenshot(path="/tmp/element-08-in-call.png")

        vad_triggered = await wait_for_bot_event("VAD: user_state=", timeout=20)
        if vad_triggered:
            print(" ✓ VAD triggered! Audio pipeline works, E2EE decryption successful.")
        else:
            print(" ✗ VAD did not trigger — either E2EE blocks audio or pipeline issue")

        speech_transcribed = await wait_for_bot_event("USER_SPEECH:", timeout=30)
        if speech_transcribed:
            print(" ✓ Speech transcribed! Full pipeline working.")
        else:
            print(" ✗ No speech transcription")

        print("[7] Checking E2EE state in logs...")
        result = subprocess.run(
            ["ssh", "root@matrix.agiliton.internal",
             "cd /opt/matrix-ai-agent && docker compose logs bot --tail=100 2>&1"],
            capture_output=True, text=True, timeout=15
        )
        for line in result.stdout.split("\n"):
            if any(kw in line for kw in ["E2EE_STATE", "VAD", "USER_SPEECH", "AGENT_SPEECH", "DEC_FAILED", "MISSING_KEY", "shared_key", "HKDF"]):
                print(f" LOG: {line.strip()}")

        await page.wait_for_timeout(5000)
        await page.screenshot(path="/tmp/element-09-final.png")

        print("\nScreenshots saved to /tmp/element-*.png")
        await browser.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--headless", action="store_true", help="Run headless")
    args = parser.parse_args()
    asyncio.run(run_test(headless=args.headless))
0 tests/__init__.py Normal file
267 tests/test_cron_brave_search.py Normal file
@@ -0,0 +1,267 @@
"""Tests for the Brave Search cron executor."""

import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from cron.brave_search import execute_brave_search


@pytest.fixture
def job():
    return {
        "id": "j1",
        "name": "BMW Search",
        "jobType": "brave_search",
        "config": {"query": "BMW X3 damaged Cyprus", "maxResults": 5},
        "targetRoom": "!room:cars",
        "dedupKeys": ["https://old-result.com"],
    }


class TestBraveSearchExecutor:
    @pytest.mark.asyncio
    async def test_returns_error_without_api_key(self, job):
        with patch.dict(os.environ, {"BRAVE_API_KEY": ""}, clear=False):
            # The module reads the key at import time, so override the module global directly
            import cron.brave_search as bs
            original_key = bs.BRAVE_API_KEY
            bs.BRAVE_API_KEY = ""
            try:
                result = await execute_brave_search(job=job, send_text=AsyncMock())
                assert result["status"] == "error"
                assert "BRAVE_API_KEY" in result["error"]
            finally:
                bs.BRAVE_API_KEY = original_key

    @pytest.mark.asyncio
    async def test_returns_error_without_query(self):
        job = {
            "id": "j1",
            "name": "Empty",
            "jobType": "brave_search",
            "config": {},
            "targetRoom": "!room:test",
            "dedupKeys": [],
        }
        import cron.brave_search as bs
        original_key = bs.BRAVE_API_KEY
        bs.BRAVE_API_KEY = "test-key"
        try:
            result = await execute_brave_search(job=job, send_text=AsyncMock())
            assert result["status"] == "error"
            assert "query" in result["error"].lower()
        finally:
            bs.BRAVE_API_KEY = original_key

    @pytest.mark.asyncio
    async def test_deduplicates_results(self, job):
        """Results with URLs already in dedupKeys should be filtered out."""
        import cron.brave_search as bs
        original_key = bs.BRAVE_API_KEY
        bs.BRAVE_API_KEY = "test-key"

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "web": {
                "results": [
                    {"title": "Old Result", "url": "https://old-result.com", "description": "Already seen"},
                    {"title": "New BMW", "url": "https://new-result.com", "description": "Fresh listing"},
                ]
            }
        }
        mock_response.raise_for_status = MagicMock()

        send_text = AsyncMock()

        with patch("httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.get = AsyncMock(return_value=mock_response)

            try:
                result = await execute_brave_search(job=job, send_text=send_text)
            finally:
                bs.BRAVE_API_KEY = original_key

        assert result["status"] == "success"
        assert result["newDedupKeys"] == ["https://new-result.com"]
        send_text.assert_called_once()
        # The message should contain only the new result
        msg = send_text.call_args[0][1]
        assert "New BMW" in msg
        assert "Old Result" not in msg

    @pytest.mark.asyncio
    async def test_no_results_status(self, job):
        """When the API returns empty results, status should be no_results."""
        import cron.brave_search as bs
        original_key = bs.BRAVE_API_KEY
        bs.BRAVE_API_KEY = "test-key"

        mock_response = MagicMock()
        mock_response.json.return_value = {"web": {"results": []}}
        mock_response.raise_for_status = MagicMock()

        with patch("httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.get = AsyncMock(return_value=mock_response)

            try:
                result = await execute_brave_search(job=job, send_text=AsyncMock())
            finally:
                bs.BRAVE_API_KEY = original_key

        assert result["status"] == "no_results"

    @pytest.mark.asyncio
    async def test_all_results_already_seen(self, job):
        """When all results are already in dedupKeys, status should be no_results."""
        import cron.brave_search as bs
        original_key = bs.BRAVE_API_KEY
        bs.BRAVE_API_KEY = "test-key"

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "web": {
                "results": [
                    {"title": "Old", "url": "https://old-result.com", "description": "Seen"},
                ]
            }
        }
        mock_response.raise_for_status = MagicMock()

        with patch("httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.get = AsyncMock(return_value=mock_response)

            try:
                result = await execute_brave_search(job=job, send_text=AsyncMock())
            finally:
                bs.BRAVE_API_KEY = original_key

        assert result["status"] == "no_results"

    @pytest.mark.asyncio
    async def test_llm_filter_keeps_matching_results(self):
        """The LLM filter should only keep results that match the criteria."""
        import cron.brave_search as bs
        orig_key, orig_url, orig_llm_key = bs.BRAVE_API_KEY, bs.LITELLM_URL, bs.LITELLM_KEY
        bs.BRAVE_API_KEY = "test-key"
        bs.LITELLM_URL = "http://llm:4000/v1"
        bs.LITELLM_KEY = "sk-test"

        job = {
            "id": "j1",
            "name": "BMW Search",
            "jobType": "brave_search",
            "config": {"query": "BMW X3 damaged", "maxResults": 5, "criteria": "Must be BMW X3, petrol, <=2019, damaged"},
            "targetRoom": "!room:cars",
            "dedupKeys": [],
        }

        brave_resp = MagicMock()
        brave_resp.json.return_value = {"web": {"results": [
            {"title": "BMW X3 2018 Unfallwagen Benzin", "url": "https://a.com", "description": "Damaged"},
            {"title": "Toyota Corolla 2020", "url": "https://b.com", "description": "Not a BMW"},
            {"title": "BMW X3 2017 Diesel crash", "url": "https://c.com", "description": "Diesel"},
        ]}}
        brave_resp.raise_for_status = MagicMock()

        llm_resp = MagicMock()
        llm_resp.json.return_value = {"choices": [{"message": {"content": "[0]"}}]}
        llm_resp.raise_for_status = MagicMock()

        send_text = AsyncMock()

        with patch("httpx.AsyncClient") as mock_cls:
            mock_client = AsyncMock()
            mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.get = AsyncMock(return_value=brave_resp)
            mock_client.post = AsyncMock(return_value=llm_resp)

            try:
                result = await execute_brave_search(job=job, send_text=send_text)
            finally:
                bs.BRAVE_API_KEY, bs.LITELLM_URL, bs.LITELLM_KEY = orig_key, orig_url, orig_llm_key

        assert result["status"] == "success"
        assert result["newDedupKeys"] == ["https://a.com"]
        msg = send_text.call_args[0][1]
        assert "Unfallwagen" in msg
        assert "Toyota" not in msg

    @pytest.mark.asyncio
    async def test_llm_filter_no_matches_returns_no_results(self):
        """When the LLM filter rejects all results, status should be no_results."""
        import cron.brave_search as bs
        orig_key, orig_url, orig_llm_key = bs.BRAVE_API_KEY, bs.LITELLM_URL, bs.LITELLM_KEY
        bs.BRAVE_API_KEY = "test-key"
        bs.LITELLM_URL = "http://llm:4000/v1"
        bs.LITELLM_KEY = "sk-test"

        job = {
            "id": "j1", "name": "Search", "jobType": "brave_search",
            "config": {"query": "test", "criteria": "Must be exactly X"},
            "targetRoom": "!room:test", "dedupKeys": [],
        }

        brave_resp = MagicMock()
        brave_resp.json.return_value = {"web": {"results": [{"title": "Nope", "url": "https://x.com", "description": "No"}]}}
        brave_resp.raise_for_status = MagicMock()

        llm_resp = MagicMock()
        llm_resp.json.return_value = {"choices": [{"message": {"content": "[]"}}]}
        llm_resp.raise_for_status = MagicMock()

        send_text = AsyncMock()

        with patch("httpx.AsyncClient") as mock_cls:
            mock_client = AsyncMock()
            mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.get = AsyncMock(return_value=brave_resp)
            mock_client.post = AsyncMock(return_value=llm_resp)

            try:
                result = await execute_brave_search(job=job, send_text=send_text)
            finally:
                bs.BRAVE_API_KEY, bs.LITELLM_URL, bs.LITELLM_KEY = orig_key, orig_url, orig_llm_key

        assert result["status"] == "no_results"
        send_text.assert_not_called()

    @pytest.mark.asyncio
    async def test_no_criteria_skips_llm_filter(self, job):
        """Without criteria, results pass through without an LLM call."""
        import cron.brave_search as bs
        orig_key = bs.BRAVE_API_KEY
        bs.BRAVE_API_KEY = "test-key"

        mock_response = MagicMock()
        mock_response.json.return_value = {"web": {"results": [{"title": "R", "url": "https://new.com", "description": "D"}]}}
        mock_response.raise_for_status = MagicMock()

        send_text = AsyncMock()

        with patch("httpx.AsyncClient") as mock_cls:
            mock_client = AsyncMock()
            mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.get = AsyncMock(return_value=mock_response)

            try:
                result = await execute_brave_search(job=job, send_text=send_text)
            finally:
                bs.BRAVE_API_KEY = orig_key

        assert result["status"] == "success"
        mock_client.post.assert_not_called()
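These tests save and restore module globals by hand. pytest's built-in `monkeypatch` fixture reverts attribute patches automatically, so a hypothetical rewrite of the no-key test could drop the try/finally (sketch only, reusing the module's `job` fixture):

```python
import pytest
from unittest.mock import AsyncMock

import cron.brave_search as bs
from cron.brave_search import execute_brave_search

@pytest.mark.asyncio
async def test_returns_error_without_api_key(monkeypatch, job):
    # restored automatically when the test ends, even on failure
    monkeypatch.setattr(bs, "BRAVE_API_KEY", "")
    result = await execute_brave_search(job=job, send_text=AsyncMock())
    assert result["status"] == "error"
    assert "BRAVE_API_KEY" in result["error"]
```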
58 tests/test_cron_browser_executor.py Normal file
@@ -0,0 +1,58 @@
"""Tests for the browser scrape executor."""

from unittest.mock import AsyncMock

import pytest

from cron.browser_executor import execute_browser_scrape


class TestBrowserScrapeExecutor:
    @pytest.mark.asyncio
    async def test_returns_error_without_profile(self):
        job = {
            "id": "j1",
            "name": "FB Scan",
            "config": {"url": "https://facebook.com/marketplace"},
            "targetRoom": "!room:test",
            "browserProfile": None,
        }
        send_text = AsyncMock()
        result = await execute_browser_scrape(job=job, send_text=send_text)
        assert result["status"] == "error"
        assert "browser profile" in result["error"].lower()
        send_text.assert_called_once()
        msg = send_text.call_args[0][1]
        assert "matrixhost.eu/settings/automations" in msg

    @pytest.mark.asyncio
    async def test_returns_error_with_expired_profile(self):
        job = {
            "id": "j1",
            "name": "FB Scan",
            "config": {"url": "https://facebook.com/marketplace"},
            "targetRoom": "!room:test",
            "browserProfile": {"id": "b1", "status": "expired", "name": "facebook"},
        }
        send_text = AsyncMock()
        result = await execute_browser_scrape(job=job, send_text=send_text)
        assert result["status"] == "error"
        assert "expired" in result["error"].lower()
        send_text.assert_called_once()
        msg = send_text.call_args[0][1]
        assert "re-record" in msg.lower()

    @pytest.mark.asyncio
    async def test_placeholder_with_active_profile(self):
        job = {
            "id": "j1",
            "name": "FB Scan",
            "config": {"url": "https://facebook.com/marketplace"},
            "targetRoom": "!room:test",
            "browserProfile": {"id": "b1", "status": "active", "name": "facebook"},
        }
        send_text = AsyncMock()
        result = await execute_browser_scrape(job=job, send_text=send_text)
        # Currently a placeholder; should indicate it is not yet implemented
        assert result["status"] == "error"
        assert "not yet implemented" in result["error"].lower()
48 tests/test_cron_executor.py Normal file
@@ -0,0 +1,48 @@
"""Tests for the cron executor dispatch."""

from unittest.mock import AsyncMock

import pytest

from cron.executor import execute_job


class TestExecuteJob:
    @pytest.mark.asyncio
    async def test_unknown_job_type_returns_error(self):
        job = {"jobType": "nonexistent", "config": {}}
        result = await execute_job(
            job=job, send_text=AsyncMock(), matrix_client=None
        )
        assert result["status"] == "error"
        assert "Unknown job type" in result["error"]

    @pytest.mark.asyncio
    async def test_dispatches_to_reminder(self):
        job = {
            "id": "j1",
            "name": "Test Reminder",
            "jobType": "reminder",
            "config": {"message": "Don't forget!"},
            "targetRoom": "!room:test",
        }
        send_text = AsyncMock()
        result = await execute_job(job=job, send_text=send_text, matrix_client=None)
        assert result["status"] == "success"
        send_text.assert_called_once()
        assert "Don't forget!" in send_text.call_args[0][1]

    @pytest.mark.asyncio
    async def test_dispatches_to_browser_scrape_no_profile(self):
        job = {
            "id": "j1",
            "name": "Scrape Test",
            "jobType": "browser_scrape",
            "config": {"url": "https://example.com"},
            "targetRoom": "!room:test",
            "browserProfile": None,
        }
        send_text = AsyncMock()
        result = await execute_job(job=job, send_text=send_text, matrix_client=None)
        assert result["status"] == "error"
        assert "browser profile" in result["error"].lower()
67 tests/test_cron_formatter.py Normal file
@@ -0,0 +1,67 @@
"""Tests for the cron result formatter."""

from cron.formatter import format_search_results, format_listings


class TestFormatSearchResults:
    def test_single_result(self):
        results = [
            {"title": "BMW X3 2018", "url": "https://example.com/1", "description": "Unfallwagen"}
        ]
        msg = format_search_results("BMW Scan", results)
        assert "BMW Scan" in msg
        assert "1 new result" in msg
        assert "BMW X3 2018" in msg
        assert "https://example.com/1" in msg
        assert "Unfallwagen" in msg
        assert "matrixhost.eu/settings/automations" in msg

    def test_multiple_results(self):
        results = [
            {"title": "Result 1", "url": "https://a.com", "description": "Desc 1"},
            {"title": "Result 2", "url": "https://b.com", "description": "Desc 2"},
            {"title": "Result 3", "url": "https://c.com", "description": ""},
        ]
        msg = format_search_results("Test Search", results)
        assert "3 new results" in msg
        assert "1." in msg
        assert "2." in msg
        assert "3." in msg

    def test_result_without_description(self):
        results = [{"title": "No Desc", "url": "https://x.com"}]
        msg = format_search_results("Search", results)
        assert "No Desc" in msg
        # Should not have an empty description line


class TestFormatListings:
    def test_single_listing(self):
        listings = [
            {
                "title": "BMW X3 2.0i",
                "price": "\u20ac4,500",
                "location": "Limassol",
                "url": "https://fb.com/123",
                "age": "2h ago",
            }
        ]
        msg = format_listings("Car Scan", listings)
        assert "Car Scan" in msg
        assert "1 new listing" in msg
        assert "BMW X3 2.0i" in msg
        assert "\u20ac4,500" in msg
        assert "Limassol" in msg
        assert "2h ago" in msg
        assert "https://fb.com/123" in msg

    def test_listing_without_optional_fields(self):
        listings = [{"title": "Bare Listing"}]
        msg = format_listings("Scan", listings)
        assert "Bare Listing" in msg
        assert "1 new listing" in msg

    def test_multiple_listings_plural(self):
        listings = [{"title": f"Item {i}"} for i in range(5)]
        msg = format_listings("Multi", listings)
        assert "5 new listings" in msg
57 tests/test_cron_reminder.py Normal file
@@ -0,0 +1,57 @@
"""Tests for the reminder cron executor."""

from unittest.mock import AsyncMock

import pytest

from cron.reminder import execute_reminder


class TestReminderExecutor:
    @pytest.mark.asyncio
    async def test_sends_reminder_to_room(self):
        job = {
            "id": "j1",
            "name": "Daily Check",
            "config": {"message": "Check your portfolio"},
            "targetRoom": "!room:finance",
        }
        send_text = AsyncMock()
        result = await execute_reminder(job=job, send_text=send_text)

        assert result["status"] == "success"
        send_text.assert_called_once()
        room_id, msg = send_text.call_args[0]
        assert room_id == "!room:finance"
        assert "Check your portfolio" in msg
        assert "Daily Check" in msg
        assert "\u23f0" in msg  # alarm clock emoji

    @pytest.mark.asyncio
    async def test_returns_error_without_message(self):
        job = {
            "id": "j1",
            "name": "Empty",
            "config": {},
            "targetRoom": "!room:test",
        }
        send_text = AsyncMock()
        result = await execute_reminder(job=job, send_text=send_text)

        assert result["status"] == "error"
        assert "message" in result["error"].lower()
        send_text.assert_not_called()

    @pytest.mark.asyncio
    async def test_empty_message_returns_error(self):
        job = {
            "id": "j1",
            "name": "Empty",
            "config": {"message": ""},
            "targetRoom": "!room:test",
        }
        send_text = AsyncMock()
        result = await execute_reminder(job=job, send_text=send_text)

        assert result["status"] == "error"
        send_text.assert_not_called()
217 tests/test_cron_scheduler.py Normal file
@@ -0,0 +1,217 @@
"""Tests for the cron scheduler module."""

import asyncio
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import zoneinfo

import pytest

from cron.scheduler import CronScheduler


@pytest.fixture
def scheduler():
    send_text = AsyncMock()
    matrix_client = MagicMock()
    sched = CronScheduler(
        portal_url="https://matrixhost.eu",
        api_key="test-key",
        matrix_client=matrix_client,
        send_text_fn=send_text,
    )
    return sched


class TestSecondsUntilNextRun:
    def test_daily_schedule_future_today(self, scheduler):
        tz = zoneinfo.ZoneInfo("Europe/Berlin")
        now = datetime.now(tz)
        # Set scheduleAt to 2 hours from now
        future_time = now + timedelta(hours=2)
        job = {
            "schedule": "daily",
            "scheduleAt": f"{future_time.hour:02d}:{future_time.minute:02d}",
            "timezone": "Europe/Berlin",
        }
        secs = scheduler._seconds_until_next_run(job)
        assert 7000 < secs < 7300  # roughly 2 hours

    def test_daily_schedule_past_today_goes_tomorrow(self, scheduler):
        tz = zoneinfo.ZoneInfo("Europe/Berlin")
        now = datetime.now(tz)
        # Set scheduleAt to 2 hours ago
        past_time = now - timedelta(hours=2)
        job = {
            "schedule": "daily",
            "scheduleAt": f"{past_time.hour:02d}:{past_time.minute:02d}",
            "timezone": "Europe/Berlin",
        }
        secs = scheduler._seconds_until_next_run(job)
        # Should be ~22 hours from now
        assert 78000 < secs < 80000

    def test_hourly_schedule(self, scheduler):
        job = {
            "schedule": "hourly",
            "scheduleAt": None,
            "timezone": "Europe/Berlin",
        }
        secs = scheduler._seconds_until_next_run(job)
        # Should be between 0 and 3600
        assert 0 <= secs <= 3600

    def test_weekdays_skips_weekend(self, scheduler):
        # The real current date is used; we only assert the next run is in the future
        job = {
            "schedule": "weekdays",
            "scheduleAt": "09:00",
            "timezone": "UTC",
        }
        secs = scheduler._seconds_until_next_run(job)
        assert secs > 0

    def test_weekly_schedule(self, scheduler):
        job = {
            "schedule": "weekly",
            "scheduleAt": "09:00",
            "timezone": "Europe/Berlin",
        }
        secs = scheduler._seconds_until_next_run(job)
        # Should be between 0 and 7 days
        assert 0 < secs <= 7 * 86400

    def test_default_timezone(self, scheduler):
        job = {
            "schedule": "daily",
            "scheduleAt": "23:59",
            "timezone": "Europe/Berlin",
        }
        secs = scheduler._seconds_until_next_run(job)
        assert secs > 0

    def test_different_timezone(self, scheduler):
        job = {
            "schedule": "daily",
            "scheduleAt": "09:00",
            "timezone": "America/New_York",
        }
        secs = scheduler._seconds_until_next_run(job)
        assert secs > 0


class TestSyncJobs:
    @pytest.mark.asyncio
    async def test_sync_adds_new_jobs(self, scheduler):
        """New jobs from the API should be registered as tasks."""
        jobs_response = {
            "jobs": [
                {
                    "id": "j1",
                    "name": "Test Job",
                    "jobType": "brave_search",
                    "schedule": "daily",
                    "scheduleAt": "09:00",
                    "timezone": "Europe/Berlin",
                    "config": {"query": "test"},
                    "targetRoom": "!room:test",
                    "enabled": True,
                    "updatedAt": "2026-03-16T00:00:00Z",
                }
            ]
        }

        with patch("httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_resp = MagicMock()
            mock_resp.json.return_value = jobs_response
            mock_resp.raise_for_status = MagicMock()
            mock_client.get = AsyncMock(return_value=mock_resp)

            await scheduler._sync_jobs()

        assert "j1" in scheduler._jobs
        assert "j1" in scheduler._tasks
        # Clean up the task
        scheduler._tasks["j1"].cancel()
        try:
            await scheduler._tasks["j1"]
        except asyncio.CancelledError:
            pass

    @pytest.mark.asyncio
    async def test_sync_removes_deleted_jobs(self, scheduler):
        """Jobs removed from the API should have their tasks cancelled."""
        # Pre-populate with a job
        mock_task = AsyncMock()
        mock_task.cancel = MagicMock()
        scheduler._jobs["old_job"] = {"id": "old_job"}
        scheduler._tasks["old_job"] = mock_task

        with patch("httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_resp = MagicMock()
            mock_resp.json.return_value = {"jobs": []}  # No jobs
            mock_resp.raise_for_status = MagicMock()
            mock_client.get = AsyncMock(return_value=mock_resp)

            await scheduler._sync_jobs()

        assert "old_job" not in scheduler._tasks
        assert "old_job" not in scheduler._jobs
        mock_task.cancel.assert_called_once()


class TestRunOnce:
    @pytest.mark.asyncio
    async def test_run_once_reports_success(self, scheduler):
        """Successful execution should be reported back to the portal."""
        job = {
            "id": "j1",
            "name": "Test",
            "jobType": "reminder",
            "config": {"message": "Hello"},
            "targetRoom": "!room:test",
        }

        with patch("httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.post = AsyncMock()

            await scheduler._run_once(job)

        # send_text should have been called with the reminder
        scheduler.send_text.assert_called_once()
        call_args = scheduler.send_text.call_args
        assert call_args[0][0] == "!room:test"
        assert "Hello" in call_args[0][1]

    @pytest.mark.asyncio
    async def test_run_once_reports_error(self, scheduler):
        """Failed execution should report the error back to the portal."""
        job = {
            "id": "j1",
            "name": "Test",
            "jobType": "brave_search",
            "config": {},  # Missing query = error
            "targetRoom": "!room:test",
            "dedupKeys": [],
        }

        with patch("httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_client.post = AsyncMock()

            await scheduler._run_once(job)

        # Should not have sent a message to the room (the executor errored),
        # but the run should still be reported back to the portal via the httpx POST
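For reference, the bounds asserted above (7000 < secs < 7300, 78000 < secs < 80000) follow from ordinary wall-clock arithmetic; a plausible sketch of the daily case, not necessarily the scheduler's real implementation:

```python
from datetime import datetime, timedelta
import zoneinfo

def seconds_until_daily(schedule_at: str, tz_name: str) -> float:
    """Seconds until the next HH:MM occurrence in the given timezone."""
    tz = zoneinfo.ZoneInfo(tz_name)
    now = datetime.now(tz)
    hour, minute = map(int, schedule_at.split(":"))
    target = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
    if target <= now:
        target += timedelta(days=1)  # already passed today, so run tomorrow
    return (target - now).total_seconds()

# A scheduleAt two hours ahead lands inside the 7000-7300 s window the tests assert
# (minute truncation shaves off up to 60 s of the nominal 7200 s).
print(seconds_until_daily("09:00", "Europe/Berlin"))
```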
52 tests/test_device_trust.py Normal file
@@ -0,0 +1,52 @@
from unittest.mock import Mock

from device_trust import CrossSignedOnlyPolicy


class TestCrossSignedOnlyPolicy:
    def setup_method(self):
        self.policy = CrossSignedOnlyPolicy()

    def _make_device(self, device_id, user_id, extra_sig_keys=None):
        device = Mock()
        device.device_id = device_id
        sigs = {f"ed25519:{device_id}": "self_sig"}
        if extra_sig_keys:
            for k, v in extra_sig_keys.items():
                sigs[k] = v
        device.signatures = {user_id: sigs}
        return device

    def test_should_trust_cross_signed(self):
        device = self._make_device(
            "DEV1", "@alice:example.com",
            extra_sig_keys={"ed25519:MASTER_KEY": "cross_sig"},
        )
        assert self.policy.should_trust("@alice:example.com", device) is True

    def test_should_not_trust_self_signed_only(self):
        device = self._make_device("DEV1", "@alice:example.com")
        assert self.policy.should_trust("@alice:example.com", device) is False

    def test_should_not_trust_no_signatures(self):
        device = Mock()
        device.device_id = "DEV1"
        device.signatures = None
        assert self.policy.should_trust("@alice:example.com", device) is False

    def test_should_not_trust_empty_user_sigs(self):
        device = Mock()
        device.device_id = "DEV1"
        device.signatures = {"@alice:example.com": {}}
        assert self.policy.should_trust("@alice:example.com", device) is False

    def test_should_not_trust_missing_user_in_sigs(self):
        device = Mock()
        device.device_id = "DEV1"
        device.signatures = {"@bob:example.com": {"ed25519:OTHER": "sig"}}
        assert self.policy.should_trust("@alice:example.com", device) is False

    def test_should_not_trust_no_signatures_attr(self):
        device = Mock(spec=[])
        device.device_id = "DEV1"
        assert self.policy.should_trust("@alice:example.com", device) is False
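`device_trust.py` itself is not among the files shown here; as a reading aid, a minimal policy that would satisfy every assertion above might look like this (the real implementation may differ):

```python
class CrossSignedOnlyPolicy:
    """Trust a device only if a key other than its own self-signing key signed it."""

    def should_trust(self, user_id: str, device) -> bool:
        # getattr with a default also covers devices lacking a signatures attribute
        signatures = getattr(device, "signatures", None)
        if not signatures:
            return False
        user_sigs = signatures.get(user_id) or {}
        self_key = f"ed25519:{device.device_id}"
        # Any signature besides the device's own counts as cross-signing
        return any(key != self_key for key in user_sigs)
```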
130 tests/test_e2ee_send.py Normal file
@@ -0,0 +1,130 @@
import logging
from unittest.mock import AsyncMock, Mock

import pytest
from nio.exceptions import OlmUnverifiedDeviceError

from device_trust import CrossSignedOnlyPolicy


class TestSendTextErrorHandling:
    """Test _send_text resilience against E2EE errors."""

    def _make_bot(self, room_send_side_effect=None):
        """Create a minimal bot-like object with a _send_text method."""
        # Importing the actual _send_text from the bot module would pull in too many deps,
        # so we replicate the patched logic here for unit testing.
        bot = Mock()
        bot.client = Mock()
        bot.client.room_send = AsyncMock(side_effect=room_send_side_effect)
        bot._md_to_html = Mock(return_value="<p>test</p>")
        return bot

    @pytest.mark.asyncio
    async def test_send_text_success(self):
        bot = self._make_bot()
        # Inline the method logic to test it
        await self._call_send_text(bot, "!room:ex.com", "hello")
        bot.client.room_send.assert_called_once()

    @pytest.mark.asyncio
    async def test_send_text_olm_unverified_does_not_crash(self, caplog):
        bot = self._make_bot(
            room_send_side_effect=OlmUnverifiedDeviceError("Device XYZABC not verified")
        )
        with caplog.at_level(logging.ERROR):
            await self._call_send_text(bot, "!room:ex.com", "hello")
        assert "unverified device" in caplog.text

    @pytest.mark.asyncio
    async def test_send_text_generic_error_does_not_crash(self, caplog):
        bot = self._make_bot(room_send_side_effect=Exception("network timeout"))
        with caplog.at_level(logging.ERROR):
            await self._call_send_text(bot, "!room:ex.com", "hello")
        assert "Send failed" in caplog.text

    async def _call_send_text(self, bot, room_id, text):
        """Replicate the _send_text logic matching the patched bot.py."""
        logger = logging.getLogger("test_e2ee")
        try:
            await bot.client.room_send(
                room_id,
                message_type="m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": text,
                    "format": "org.matrix.custom.html",
                    "formatted_body": bot._md_to_html(text),
                },
            )
        except OlmUnverifiedDeviceError as e:
            logger.error("E2EE send failed in room %s: unverified device — %s", room_id, e)
        except Exception as e:
            logger.error("Send failed in room %s: %s", room_id, e)


class TestOnSyncDeviceVerification:
    """Test that on_sync blacklists unverified devices instead of skipping them."""

    def _make_device(self, device_id, cross_signed=False):
        device = Mock()
        device.device_id = device_id
        device.verified = False
        if cross_signed:
            device.signatures = {
                "@user:ex.com": {
                    f"ed25519:{device_id}": "self",
                    "ed25519:MASTER": "cross",
                }
            }
        else:
            device.signatures = {
                "@user:ex.com": {f"ed25519:{device_id}": "self"}
            }
        return device

    def test_blacklists_unverified_device(self):
        policy = CrossSignedOnlyPolicy()
        client = Mock()
        device = self._make_device("BADDEV", cross_signed=False)

        # Simulate the on_sync logic
        if not device.verified:
            if policy.should_trust("@user:ex.com", device):
                client.verify_device(device)
            else:
                client.blacklist_device(device)

        client.blacklist_device.assert_called_once_with(device)
        client.verify_device.assert_not_called()

    def test_verifies_cross_signed_device(self):
        policy = CrossSignedOnlyPolicy()
        client = Mock()
        device = self._make_device("GOODDEV", cross_signed=True)

        if not device.verified:
            if policy.should_trust("@user:ex.com", device):
                client.verify_device(device)
            else:
                client.blacklist_device(device)

        client.verify_device.assert_called_once_with(device)
        client.blacklist_device.assert_not_called()

    def test_mixed_devices(self):
        policy = CrossSignedOnlyPolicy()
        client = Mock()
        good = self._make_device("GOOD", cross_signed=True)
        bad = self._make_device("BAD", cross_signed=False)

        for device in [good, bad]:
            if not device.verified:
                if policy.should_trust("@user:ex.com", device):
                    client.verify_device(device)
                else:
                    client.blacklist_device(device)

        client.verify_device.assert_called_once_with(good)
        client.blacklist_device.assert_called_once_with(bad)
41 tests/test_needs_query_rewrite.py Normal file
@@ -0,0 +1,41 @@
"""Heuristic gate for `_rewrite_query` (bot.py). Skips the LLM round-trip when
the message has no pronouns or deictic references that would need context."""

from bot import Bot


def _needs(msg: str) -> bool:
    return Bot._needs_query_rewrite(msg)


def test_short_message_skipped():
    assert _needs("hi") is False
    assert _needs("ok") is False


def test_self_contained_no_pronouns_skipped():
    assert _needs("What is the capital of France?") is False
    assert _needs("Summarize the Q3 earnings report") is False
    assert _needs("Wie ist das Wetter in Berlin morgen") is False


def test_english_pronouns_trigger():
    assert _needs("What does it mean?") is True
    assert _needs("Can you fix that?") is True
    assert _needs("Tell me more about them") is True


def test_german_pronouns_trigger():
    assert _needs("Was bedeutet das?") is True
    assert _needs("Kannst du es noch einmal erklären") is True
    assert _needs("Wer sind sie?") is True


def test_french_pronouns_trigger():
    assert _needs("Qu'est-ce que ça veut dire?") is True
    assert _needs("Parle-moi de lui") is True


def test_empty_or_whitespace():
    assert _needs("") is False
    assert _needs(" ") is False