diff --git a/cron/brave_search.py b/cron/brave_search.py
index f8ba0f3..c92875f 100644
--- a/cron/brave_search.py
+++ b/cron/brave_search.py
@@ -1,5 +1,6 @@
-"""Brave Search executor for cron jobs."""
+"""Brave Search executor for cron jobs with optional LLM filtering."""
 
+import json
 import logging
 import os
 
@@ -10,15 +11,95 @@ from .formatter import format_search_results
 logger = logging.getLogger(__name__)
 
 BRAVE_API_KEY = os.environ.get("BRAVE_API_KEY", "")
+LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
+LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "")
+FILTER_MODEL = os.environ.get("BASE_MODEL", "claude-haiku")
+
+FILTER_SYSTEM_PROMPT = """You are a search result filter. Given a list of search results and filtering criteria, evaluate each result and return ONLY the ones that match the criteria.
+
+Return a JSON array of indices (0-based) of results that match. If none match, return an empty array [].
+Only return the JSON array, nothing else."""
+
+
+async def _llm_filter(results: list[dict], criteria: str) -> list[dict]:
+    """Use LLM to filter search results against user-defined criteria.
+
+    Fails open: if the LLM is not configured, the request fails, or the
+    reply is not a JSON array of integers, all results are returned
+    unfiltered.
+    """
+    if not LITELLM_URL or not LITELLM_KEY:
+        logger.warning("LLM not configured, skipping filter")
+        return results
+
+    # Build a concise representation of results for the LLM
+    result_descriptions = []
+    for i, r in enumerate(results):
+        title = r.get("title", "")
+        desc = r.get("description", "")
+        url = r.get("url", "")
+        result_descriptions.append(f"[{i}] {title} — {desc} ({url})")
+
+    user_msg = (
+        f"**Criteria:** {criteria}\n\n"
+        f"**Results:**\n" + "\n".join(result_descriptions)
+    )
+
+    try:
+        async with httpx.AsyncClient(timeout=30.0) as client:
+            resp = await client.post(
+                # Tolerate a trailing slash in LITELLM_BASE_URL
+                f"{LITELLM_URL.rstrip('/')}/chat/completions",
+                headers={"Authorization": f"Bearer {LITELLM_KEY}"},
+                json={
+                    "model": FILTER_MODEL,
+                    "messages": [
+                        {"role": "system", "content": FILTER_SYSTEM_PROMPT},
+                        {"role": "user", "content": user_msg},
+                    ],
+                    "temperature": 0,
+                    "max_tokens": 200,
+                },
+            )
+            resp.raise_for_status()
+            data = resp.json()
+
+        reply = data["choices"][0]["message"]["content"].strip()
+        # Models often wrap JSON in a markdown fence (```json ... ```)
+        if reply.startswith("```"):
+            reply = reply.strip("`").removeprefix("json").strip()
+        # Parse the JSON array of indices
+        indices = json.loads(reply)
+        # bool is a subclass of int, so reject it explicitly
+        if not isinstance(indices, list) or not all(
+            isinstance(i, int) and not isinstance(i, bool) for i in indices
+        ):
+            logger.warning("LLM filter returned invalid indices: %s", reply)
+            return results
+
+        # dict.fromkeys drops repeated indices while keeping the LLM's order
+        filtered = [
+            results[i] for i in dict.fromkeys(indices) if 0 <= i < len(results)
+        ]
+        logger.info(
+            "LLM filter: %d/%d results matched criteria",
+            len(filtered), len(results),
+        )
+        return filtered
+
+    except Exception as exc:
+        logger.warning("LLM filter failed, returning all results: %s", exc)
+        return results
 
 
 async def execute_brave_search(job: dict, send_text, **_kwargs) -> dict:
-    """Run a Brave Search query, dedup against known keys, post new results to Matrix."""
+    """Run a Brave Search query, dedup, optionally LLM-filter, post to Matrix."""
     if not BRAVE_API_KEY:
         return {"status": "error", "error": "BRAVE_API_KEY not configured"}
 
     config = job.get("config", {})
     query = config.get("query", "")
+    criteria = config.get("criteria", "")
     max_results = config.get("maxResults", 10)
     target_room = job["targetRoom"]
     dedup_keys = set(job.get("dedupKeys", []))
@@ -49,6 +130,12 @@ async def execute_brave_search(job: dict, send_text, **_kwargs) -> dict:
     if not new_results:
         return {"status": "no_results"}
 
+    # LLM filter if criteria provided
+    if criteria:
+        new_results = await _llm_filter(new_results, criteria)
+        if not new_results:
+            return {"status": "no_results"}
+
     msg = format_search_results(job["name"], new_results)
     await send_text(target_room, msg)
 
diff --git a/tests/test_cron_brave_search.py b/tests/test_cron_brave_search.py
index a9d8d5c..99d9ae1 100644
--- a/tests/test_cron_brave_search.py
+++ b/tests/test_cron_brave_search.py
@@ -148,3 +148,148 @@ class TestBraveSearchExecutor:
                 bs.BRAVE_API_KEY = original_key
 
         assert result["status"] == "no_results"
+
+    @pytest.mark.asyncio
+    async def test_llm_filter_keeps_matching_results(self):
+        """LLM filter should only keep results that match criteria."""
+        import cron.brave_search as bs
+        orig_key, orig_url, orig_llm_key = bs.BRAVE_API_KEY, bs.LITELLM_URL, bs.LITELLM_KEY
+        bs.BRAVE_API_KEY = "test-key"
+        bs.LITELLM_URL = "http://llm:4000/v1"
+        bs.LITELLM_KEY = "sk-test"
+
+        job = {
+            "id": "j1",
+            "name": "BMW Search",
+            "jobType": "brave_search",
+            "config": {"query": "BMW X3 damaged", "maxResults": 5, "criteria": "Must be BMW X3, petrol, <=2019, damaged"},
+            "targetRoom": "!room:cars",
+            "dedupKeys": [],
+        }
+
+        brave_resp = MagicMock()
+        brave_resp.json.return_value = {"web": {"results": [
+            {"title": "BMW X3 2018 Unfallwagen Benzin", "url": "https://a.com", "description": "Damaged"},
+            {"title": "Toyota Corolla 2020", "url": "https://b.com", "description": "Not a BMW"},
+            {"title": "BMW X3 2017 Diesel crash", "url": "https://c.com", "description": "Diesel"},
+        ]}}
+        brave_resp.raise_for_status = MagicMock()
+
+        llm_resp = MagicMock()
+        llm_resp.json.return_value = {"choices": [{"message": {"content": "[0]"}}]}
+        llm_resp.raise_for_status = MagicMock()
+
+        send_text = AsyncMock()
+
+        with patch("httpx.AsyncClient") as mock_cls:
+            mock_client = AsyncMock()
+            mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
+            mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
+            mock_client.get = AsyncMock(return_value=brave_resp)
+            mock_client.post = AsyncMock(return_value=llm_resp)
+
+            try:
+                result = await execute_brave_search(job=job, send_text=send_text)
+            finally:
+                bs.BRAVE_API_KEY, bs.LITELLM_URL, bs.LITELLM_KEY = orig_key, orig_url, orig_llm_key
+
+        assert result["status"] == "success"
+        assert result["newDedupKeys"] == ["https://a.com"]
+        msg = send_text.call_args[0][1]
+        assert "Unfallwagen" in msg
+        assert "Toyota" not in msg
+
+    @pytest.mark.asyncio
+    async def test_llm_filter_no_matches_returns_no_results(self):
+        """When LLM filter rejects all results, status should be no_results."""
+        import cron.brave_search as bs
+        orig_key, orig_url, orig_llm_key = bs.BRAVE_API_KEY, bs.LITELLM_URL, bs.LITELLM_KEY
+        bs.BRAVE_API_KEY = "test-key"
+        bs.LITELLM_URL = "http://llm:4000/v1"
+        bs.LITELLM_KEY = "sk-test"
+
+        job = {
+            "id": "j1", "name": "Search", "jobType": "brave_search",
+            "config": {"query": "test", "criteria": "Must be exactly X"},
+            "targetRoom": "!room:test", "dedupKeys": [],
+        }
+
+        brave_resp = MagicMock()
+        brave_resp.json.return_value = {"web": {"results": [{"title": "Nope", "url": "https://x.com", "description": "No"}]}}
+        brave_resp.raise_for_status = MagicMock()
+
+        llm_resp = MagicMock()
+        llm_resp.json.return_value = {"choices": [{"message": {"content": "[]"}}]}
+        llm_resp.raise_for_status = MagicMock()
+
+        send_text = AsyncMock()
+
+        with patch("httpx.AsyncClient") as mock_cls:
+            mock_client = AsyncMock()
+            mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
+            mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
+            mock_client.get = AsyncMock(return_value=brave_resp)
+            mock_client.post = AsyncMock(return_value=llm_resp)
+
+            try:
+                result = await execute_brave_search(job=job, send_text=send_text)
+            finally:
+                bs.BRAVE_API_KEY, bs.LITELLM_URL, bs.LITELLM_KEY = orig_key, orig_url, orig_llm_key
+
+        assert result["status"] == "no_results"
+        send_text.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_no_criteria_skips_llm_filter(self, job):
+        """Without criteria, results pass through without LLM call."""
+        import cron.brave_search as bs
+        orig_key = bs.BRAVE_API_KEY
+        bs.BRAVE_API_KEY = "test-key"
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"web": {"results": [{"title": "R", "url": "https://new.com", "description": "D"}]}}
+        mock_response.raise_for_status = MagicMock()
+
+        send_text = AsyncMock()
+
+        with patch("httpx.AsyncClient") as mock_cls:
+            mock_client = AsyncMock()
+            mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
+            mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
+            mock_client.get = AsyncMock(return_value=mock_response)
+
+            try:
+                result = await execute_brave_search(job=job, send_text=send_text)
+            finally:
+                bs.BRAVE_API_KEY = orig_key
+
+        assert result["status"] == "success"
+        mock_client.post.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_llm_filter_accepts_fenced_reply_and_dedups_indices(self):
+        """A markdown-fenced reply is parsed; repeated indices are kept once."""
+        import cron.brave_search as bs
+        orig_url, orig_llm_key = bs.LITELLM_URL, bs.LITELLM_KEY
+        bs.LITELLM_URL = "http://llm:4000/v1/"
+        bs.LITELLM_KEY = "sk-test"
+
+        results = [{"title": "A", "url": "https://a.com"}, {"title": "B", "url": "https://b.com"}]
+
+        llm_resp = MagicMock()
+        llm_resp.json.return_value = {"choices": [{"message": {"content": "```json\n[1, 1]\n```"}}]}
+        llm_resp.raise_for_status = MagicMock()
+
+        with patch("httpx.AsyncClient") as mock_cls:
+            mock_client = AsyncMock()
+            mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
+            mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
+            mock_client.post = AsyncMock(return_value=llm_resp)
+
+            try:
+                filtered = await bs._llm_filter(results, "anything")
+            finally:
+                bs.LITELLM_URL, bs.LITELLM_KEY = orig_url, orig_llm_key
+
+        assert filtered == [results[1]]
+        assert mock_client.post.call_args[0][0] == "http://llm:4000/v1/chat/completions"