security: enforce per-user data isolation in memory service
- Make user_id required on all request models with field validators
- Always include user_id in WHERE clause for chunk queries (prevents cross-user data leak)
- Add bearer token auth on all endpoints except /health
- Add composite index on (user_id, room_id) for conversation_chunks
- Bot: guard query_chunks with sender check, pass room_id, send auth token
- Docker: pass MEMORY_SERVICE_TOKEN to both bot and memory-service

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
36
bot.py
36
bot.py
@@ -70,6 +70,7 @@ LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
|
||||
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
|
||||
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
|
||||
MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
|
||||
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
|
||||
CONFLUENCE_URL = os.environ.get("CONFLUENCE_BASE_URL", "")
|
||||
CONFLUENCE_USER = os.environ.get("CONFLUENCE_USER", "")
|
||||
CONFLUENCE_TOKEN = os.environ.get("CONFLUENCE_TOKEN", "")
|
||||
@@ -444,10 +445,17 @@ class DocumentRAG:
|
||||
class MemoryClient:
|
||||
"""Async HTTP client for the memory-service."""
|
||||
|
||||
def __init__(self, base_url: str):
|
||||
def __init__(self, base_url: str, token: str = ""):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.token = token
|
||||
self.enabled = bool(base_url)
|
||||
|
||||
def _headers(self) -> dict:
|
||||
h = {}
|
||||
if self.token:
|
||||
h["Authorization"] = f"Bearer {self.token}"
|
||||
return h
|
||||
|
||||
async def store(self, user_id: str, fact: str, source_room: str = ""):
|
||||
if not self.enabled:
|
||||
return
|
||||
@@ -456,6 +464,7 @@ class MemoryClient:
|
||||
await client.post(
|
||||
f"{self.base_url}/memories/store",
|
||||
json={"user_id": user_id, "fact": fact, "source_room": source_room},
|
||||
headers=self._headers(),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Memory store failed", exc_info=True)
|
||||
@@ -468,6 +477,7 @@ class MemoryClient:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/memories/query",
|
||||
json={"user_id": user_id, "query": query, "top_k": top_k},
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("results", [])
|
||||
@@ -480,7 +490,10 @@ class MemoryClient:
|
||||
return 0
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.delete(f"{self.base_url}/memories/{user_id}")
|
||||
resp = await client.delete(
|
||||
f"{self.base_url}/memories/{user_id}",
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("deleted", 0)
|
||||
except Exception:
|
||||
@@ -492,7 +505,10 @@ class MemoryClient:
|
||||
return []
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(f"{self.base_url}/memories/{user_id}")
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/memories/{user_id}",
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("memories", [])
|
||||
except Exception:
|
||||
@@ -512,19 +528,24 @@ class MemoryClient:
|
||||
"chunk_text": chunk_text, "summary": summary,
|
||||
"source_event_id": source_event_id, "original_ts": original_ts,
|
||||
},
|
||||
headers=self._headers(),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Chunk store failed", exc_info=True)
|
||||
|
||||
async def query_chunks(self, query: str, user_id: str = "", room_id: str = "",
|
||||
async def query_chunks(self, query: str, user_id: str, room_id: str = "",
|
||||
top_k: int = 5) -> list[dict]:
|
||||
if not self.enabled:
|
||||
return []
|
||||
if not user_id:
|
||||
logger.error("query_chunks called with empty user_id — returning empty")
|
||||
return []
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/chunks/query",
|
||||
json={"user_id": user_id, "room_id": room_id, "query": query, "top_k": top_k},
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("results", [])
|
||||
@@ -961,7 +982,7 @@ class Bot:
|
||||
self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY,
|
||||
rag_endpoint=RAG_ENDPOINT, rag_auth_token=RAG_AUTH_TOKEN)
|
||||
self.key_manager = RAGKeyManager(self.client, PORTAL_URL, BOT_API_KEY)
|
||||
self.memory = MemoryClient(MEMORY_SERVICE_URL)
|
||||
self.memory = MemoryClient(MEMORY_SERVICE_URL, token=MEMORY_SERVICE_TOKEN)
|
||||
self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY)
|
||||
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
|
||||
self._documents_cache: dict[str, str | None] = {} # matrix_user_id -> connected status
|
||||
@@ -2092,7 +2113,10 @@ class Bot:
|
||||
memory_context = self._format_memories(memories)
|
||||
|
||||
# Query relevant conversation chunks (RAG over chat history)
|
||||
chunks = await self.memory.query_chunks(search_query, user_id=sender or "", top_k=5)
|
||||
if sender:
|
||||
chunks = await self.memory.query_chunks(search_query, user_id=sender, room_id=room.room_id, top_k=5)
|
||||
else:
|
||||
chunks = []
|
||||
chunk_context = self._format_chunks(chunks)
|
||||
|
||||
# Include room document context (PDFs, Confluence pages, images uploaded to room)
|
||||
|
||||
Reference in New Issue
Block a user