Track running job IDs to avoid creating duplicate Skyvern tasks when the pending check runs faster than the task completes. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
308 lines
12 KiB
Python
308 lines
12 KiB
Python
"""Cron job scheduler that syncs with matrixhost-web API and executes jobs."""
|
|
|
|
import asyncio
import logging
import zoneinfo
from datetime import datetime, timedelta, timezone

import httpx

from .executor import execute_job
from pipelines import PipelineEngine, PipelineStateManager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
SYNC_INTERVAL = 300 # 5 minutes — full job reconciliation
|
|
PENDING_CHECK_INTERVAL = 15 # 15 seconds — fast check for manual triggers
|
|
|
|
|
|
class CronScheduler:
    """Fetches enabled cron jobs from the matrixhost portal and runs them on schedule.

    Two cadences drive everything, for both plain cron jobs and pipelines:

    * a full reconciliation against the portal every ``SYNC_INTERVAL`` seconds,
      which (re)schedules one background loop per job/pipeline, and
    * a fast poll every ``PENDING_CHECK_INTERVAL`` seconds that picks up manual
      triggers (records whose ``lastStatus`` is ``"pending"``).

    IDs of in-flight runs are tracked (``_running_jobs`` / ``_running_pipelines``)
    so the fast poll cannot start a duplicate run while a previous run of the
    same job or pipeline is still executing.
    """

    def __init__(self, portal_url: str, api_key: str, matrix_client, send_text_fn,
                 llm_client=None, default_model: str = "claude-haiku",
                 escalation_model: str = "claude-sonnet"):
        """Create a scheduler bound to one portal instance.

        Args:
            portal_url: Base URL of the matrixhost-web portal; a trailing slash
                is stripped so endpoint paths can be appended safely.
            api_key: Sent as the ``x-api-key`` header on every portal request.
            matrix_client: Matrix client passed through to job/pipeline execution.
            send_text_fn: Async callable used to deliver text messages.
            llm_client: Optional LLM client handed to the pipeline engine.
            default_model: Model pipelines use unless they escalate.
            escalation_model: Model pipelines escalate to.
        """
        self.portal_url = portal_url.rstrip("/")
        self.api_key = api_key
        self.matrix_client = matrix_client
        self.send_text = send_text_fn
        self._jobs: dict[str, dict] = {}           # job id -> job data
        self._tasks: dict[str, asyncio.Task] = {}  # job id -> per-job scheduler task
        self._running = False

        # Pipeline engine
        self._pipeline_state = PipelineStateManager(portal_url, api_key)
        self.pipeline_engine = PipelineEngine(
            state=self._pipeline_state,
            send_text=send_text_fn,
            matrix_client=matrix_client,
            llm_client=llm_client,
            default_model=default_model,
            escalation_model=escalation_model,
        )
        self._pipelines: dict[str, dict] = {}               # pipeline id -> data
        self._pipeline_tasks: dict[str, asyncio.Task] = {}  # pipeline id -> task
        self._running_jobs: set[str] = set()        # job IDs currently executing
        self._running_pipelines: set[str] = set()   # pipeline IDs currently executing

    async def start(self):
        """Start the scheduler background loops (runs until :meth:`stop`)."""
        self._running = True
        logger.info("Cron scheduler starting")
        await asyncio.sleep(15)  # wait for bot to stabilize
        # Run full sync + fast pending check in parallel
        await asyncio.gather(
            self._full_sync_loop(),
            self._pending_check_loop(),
            self._pipeline_sync_loop(),
            self._pipeline_pending_check_loop(),
        )

    async def _full_sync_loop(self):
        """Full job reconciliation every SYNC_INTERVAL seconds."""
        while self._running:
            try:
                await self._sync_jobs()
            except Exception:
                logger.warning("Cron job sync failed", exc_info=True)
            await asyncio.sleep(SYNC_INTERVAL)

    async def _pending_check_loop(self):
        """Fast poll for manual job triggers every PENDING_CHECK_INTERVAL seconds."""
        while self._running:
            try:
                await self._check_pending()
            except Exception:
                # Best-effort: the full sync will catch anything missed here.
                logger.debug("Pending check failed", exc_info=True)
            await asyncio.sleep(PENDING_CHECK_INTERVAL)

    async def _check_pending(self):
        """Quick check for jobs with lastStatus='pending' and run them.

        Skips any job whose ID is already in ``_running_jobs`` so a slow run
        is not duplicated by subsequent polls.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                f"{self.portal_url}/api/cron/jobs/active",
                headers={"x-api-key": self.api_key},
            )
            if resp.status_code != 200:
                return  # best-effort poll; full sync handles persistent errors
            data = resp.json()

        for job in data.get("jobs", []):
            if job.get("lastStatus") == "pending" and job["id"] not in self._running_jobs:
                logger.info("Pending trigger: %s", job["name"])
                # Mark before scheduling so the very next poll sees it in-flight.
                self._running_jobs.add(job["id"])
                asyncio.create_task(self._run_once(job))

    async def _pipeline_sync_loop(self):
        """Full pipeline reconciliation every SYNC_INTERVAL seconds."""
        while self._running:
            try:
                await self._sync_pipelines()
            except Exception:
                logger.warning("Pipeline sync failed", exc_info=True)
            await asyncio.sleep(SYNC_INTERVAL)

    async def _pipeline_pending_check_loop(self):
        """Fast poll for manually triggered pipelines every PENDING_CHECK_INTERVAL seconds."""
        while self._running:
            try:
                await self._check_pending_pipelines()
            except Exception:
                logger.debug("Pipeline pending check failed", exc_info=True)
            await asyncio.sleep(PENDING_CHECK_INTERVAL)

    async def _sync_pipelines(self):
        """Fetch active pipelines from portal and reconcile with scheduled tasks."""
        pipelines = await self._pipeline_state.fetch_active_pipelines()
        remote = {p["id"]: p for p in pipelines}

        # Remove pipelines no longer active
        for pid in list(self._pipeline_tasks):
            if pid not in remote:
                logger.info("Removing pipeline %s (no longer active)", pid)
                self._pipeline_tasks[pid].cancel()
                del self._pipeline_tasks[pid]
                self._pipelines.pop(pid, None)

        # Add/update cron-triggered pipelines
        for pid, pipeline in remote.items():
            existing = self._pipelines.get(pid)
            if existing and existing.get("updatedAt") == pipeline.get("updatedAt"):
                continue  # unchanged since last sync

            # Cancel old loop if the pipeline definition changed.
            if pid in self._pipeline_tasks:
                self._pipeline_tasks[pid].cancel()

            self._pipelines[pid] = pipeline

            if pipeline.get("triggerType") == "cron":
                self._pipeline_tasks[pid] = asyncio.create_task(
                    self._pipeline_cron_loop(pipeline), name=f"pipeline-{pid}"
                )
                logger.info("Scheduled pipeline: %s (%s @ %s)",
                            pipeline["name"], pipeline.get("schedule", ""), pipeline.get("scheduleAt", ""))

    async def _check_pending_pipelines(self):
        """Check for pipelines with lastStatus='pending' and run them.

        Mirrors :meth:`_check_pending`: pipelines already in-flight (tracked in
        ``_running_pipelines``) are skipped, otherwise a slow pipeline would be
        re-triggered every poll until the portal clears its pending status.
        """
        pipelines = await self._pipeline_state.fetch_active_pipelines()
        for pipeline in pipelines:
            if (pipeline.get("lastStatus") == "pending"
                    and pipeline["id"] not in self._running_pipelines):
                logger.info("Pending pipeline trigger: %s", pipeline["name"])
                # Mark before scheduling so the very next poll sees it in-flight.
                self._running_pipelines.add(pipeline["id"])
                asyncio.create_task(self._run_pipeline_once(pipeline))

    async def _pipeline_cron_loop(self, pipeline: dict):
        """Run a pipeline on its cron schedule until cancelled."""
        try:
            while True:
                sleep_secs = self._seconds_until_next_run(pipeline)
                if sleep_secs > 0:
                    await asyncio.sleep(sleep_secs)
                await self._run_pipeline_once(pipeline)
        except asyncio.CancelledError:
            pass  # normal shutdown/reschedule path

    async def _run_pipeline_once(self, pipeline: dict):
        """Execute one pipeline run, keeping the duplicate-run guard accurate."""
        pid = pipeline["id"]
        self._running_pipelines.add(pid)  # idempotent if already marked by the poll
        try:
            await self.pipeline_engine.run(pipeline)
        finally:
            self._running_pipelines.discard(pid)

    def get_file_upload_pipelines(self) -> list[dict]:
        """Return all active file_upload-triggered pipelines."""
        return [p for p in self._pipelines.values() if p.get("triggerType") == "file_upload"]

    async def stop(self):
        """Stop all loops and cancel every scheduled job/pipeline task."""
        self._running = False
        for task in self._tasks.values():
            task.cancel()
        self._tasks.clear()
        for task in self._pipeline_tasks.values():
            task.cancel()
        self._pipeline_tasks.clear()
        # Guards are only meaningful while loops run; reset for a clean restart.
        self._running_jobs.clear()
        self._running_pipelines.clear()

    async def _sync_jobs(self):
        """Fetch active jobs from portal and reconcile with running tasks."""
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get(
                f"{self.portal_url}/api/cron/jobs/active",
                headers={"x-api-key": self.api_key},
            )
            resp.raise_for_status()
            data = resp.json()

        remote_jobs = {j["id"]: j for j in data.get("jobs", [])}

        # Remove jobs that are no longer active
        for job_id in list(self._tasks):
            if job_id not in remote_jobs:
                logger.info("Removing cron job %s (no longer active)", job_id)
                self._tasks[job_id].cancel()
                del self._tasks[job_id]
                self._jobs.pop(job_id, None)

        # Add/update jobs
        for job_id, job in remote_jobs.items():
            existing = self._jobs.get(job_id)
            if existing and existing.get("updatedAt") == job.get("updatedAt"):
                continue  # unchanged

            # Cancel old task if updating
            if job_id in self._tasks:
                self._tasks[job_id].cancel()

            self._jobs[job_id] = job
            self._tasks[job_id] = asyncio.create_task(
                self._job_loop(job), name=f"cron-{job_id}"
            )
            logger.info("Scheduled cron job: %s (%s @ %s %s)",
                        job["name"], job["schedule"], job.get("scheduleAt", ""), job["timezone"])

    async def _job_loop(self, job: dict):
        """Run a job on its schedule forever (until cancelled)."""
        try:
            while True:
                sleep_secs = self._seconds_until_next_run(job)
                if sleep_secs > 0:
                    await asyncio.sleep(sleep_secs)
                await self._run_once(job)
        except asyncio.CancelledError:
            pass  # normal shutdown/reschedule path

    async def _run_once(self, job: dict):
        """Execute a single job run and report results back to the portal."""
        job_id = job["id"]
        # Register here (not only in _check_pending) so *scheduled* runs are
        # also protected from being duplicated by the fast pending poll.
        # Idempotent when the pending poll already added the ID.
        self._running_jobs.add(job_id)
        logger.info("Running cron job: %s (%s)", job["name"], job["jobType"])
        try:
            result = await execute_job(
                job=job,
                send_text=self.send_text,
                matrix_client=self.matrix_client,
            )
            await self._report_result(job_id, result)
        except Exception as exc:
            logger.error("Cron job %s failed: %s", job["name"], exc, exc_info=True)
            await self._report_result(job_id, {
                "status": "error",
                "error": str(exc),
            })
        finally:
            self._running_jobs.discard(job_id)

    async def _report_result(self, job_id: str, result: dict):
        """Report job execution result back to the portal (best-effort)."""
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                await client.post(
                    f"{self.portal_url}/api/cron/jobs/{job_id}/result",
                    headers={"x-api-key": self.api_key},
                    json=result,
                )
        except Exception:
            logger.warning("Failed to report cron result for %s", job_id, exc_info=True)

    def _seconds_until_next_run(self, job: dict) -> float:
        """Calculate seconds until the next scheduled run.

        Supports ``schedule`` values ``hourly``, ``daily``, ``weekly`` (Mondays)
        and ``weekdays``; anything else falls back to daily. ``scheduleAt`` is an
        ``HH:MM`` local time in the job's ``timezone`` (default Europe/Berlin).
        """
        schedule = job["schedule"]
        schedule_at = job.get("scheduleAt", "09:00") or "09:00"
        tz = zoneinfo.ZoneInfo(job.get("timezone", "Europe/Berlin"))
        now = datetime.now(tz)

        hour, minute = (int(x) for x in schedule_at.split(":"))

        if schedule == "hourly":
            # Run at the top of every hour. Use timedelta, not replace(hour=+1):
            # replace() would raise ValueError at 23:xx (hour 24 is invalid).
            next_run = now.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
            return (next_run - now).total_seconds()

        if schedule == "weekly":
            # Runs on Mondays (weekday 0).
            days_ahead = (0 - now.weekday()) % 7 or 7
            next_run = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
            if now.weekday() == 0 and next_run > now:
                days_ahead = 0  # today's slot hasn't passed yet
            next_run += timedelta(days=days_ahead)
            if next_run <= now:
                next_run += timedelta(days=7)
            return (next_run - now).total_seconds()

        if schedule == "weekdays":
            next_run = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
            if next_run <= now:
                next_run += timedelta(days=1)
            # Skip weekends (Saturday=5, Sunday=6)
            while next_run.weekday() >= 5:
                next_run += timedelta(days=1)
            return (next_run - now).total_seconds()

        # "daily" and any unrecognized schedule: next occurrence of HH:MM.
        next_run = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
        if next_run <= now:
            next_run += timedelta(days=1)
        return (next_run - now).total_seconds()
|