security: enforce per-user data isolation in memory service
- Make user_id required on all request models with field validators - Always include user_id in WHERE clause for chunk queries (prevents cross-user data leak) - Add bearer token auth on all endpoints except /health - Add composite index on (user_id, room_id) for conversation_chunks - Bot: guard query_chunks with sender check, pass room_id, send auth token - Docker: pass MEMORY_SERVICE_TOKEN to both bot and memory-service Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
36
bot.py
36
bot.py
@@ -70,6 +70,7 @@ LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
|
|||||||
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
|
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
|
||||||
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
|
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "claude-sonnet")
|
||||||
MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
|
MEMORY_SERVICE_URL = os.environ.get("MEMORY_SERVICE_URL", "http://memory-service:8090")
|
||||||
|
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
|
||||||
CONFLUENCE_URL = os.environ.get("CONFLUENCE_BASE_URL", "")
|
CONFLUENCE_URL = os.environ.get("CONFLUENCE_BASE_URL", "")
|
||||||
CONFLUENCE_USER = os.environ.get("CONFLUENCE_USER", "")
|
CONFLUENCE_USER = os.environ.get("CONFLUENCE_USER", "")
|
||||||
CONFLUENCE_TOKEN = os.environ.get("CONFLUENCE_TOKEN", "")
|
CONFLUENCE_TOKEN = os.environ.get("CONFLUENCE_TOKEN", "")
|
||||||
@@ -444,10 +445,17 @@ class DocumentRAG:
|
|||||||
class MemoryClient:
|
class MemoryClient:
|
||||||
"""Async HTTP client for the memory-service."""
|
"""Async HTTP client for the memory-service."""
|
||||||
|
|
||||||
def __init__(self, base_url: str):
|
def __init__(self, base_url: str, token: str = ""):
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.token = token
|
||||||
self.enabled = bool(base_url)
|
self.enabled = bool(base_url)
|
||||||
|
|
||||||
|
def _headers(self) -> dict:
|
||||||
|
h = {}
|
||||||
|
if self.token:
|
||||||
|
h["Authorization"] = f"Bearer {self.token}"
|
||||||
|
return h
|
||||||
|
|
||||||
async def store(self, user_id: str, fact: str, source_room: str = ""):
|
async def store(self, user_id: str, fact: str, source_room: str = ""):
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
@@ -456,6 +464,7 @@ class MemoryClient:
|
|||||||
await client.post(
|
await client.post(
|
||||||
f"{self.base_url}/memories/store",
|
f"{self.base_url}/memories/store",
|
||||||
json={"user_id": user_id, "fact": fact, "source_room": source_room},
|
json={"user_id": user_id, "fact": fact, "source_room": source_room},
|
||||||
|
headers=self._headers(),
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Memory store failed", exc_info=True)
|
logger.warning("Memory store failed", exc_info=True)
|
||||||
@@ -468,6 +477,7 @@ class MemoryClient:
|
|||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
f"{self.base_url}/memories/query",
|
f"{self.base_url}/memories/query",
|
||||||
json={"user_id": user_id, "query": query, "top_k": top_k},
|
json={"user_id": user_id, "query": query, "top_k": top_k},
|
||||||
|
headers=self._headers(),
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json().get("results", [])
|
return resp.json().get("results", [])
|
||||||
@@ -480,7 +490,10 @@ class MemoryClient:
|
|||||||
return 0
|
return 0
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
resp = await client.delete(f"{self.base_url}/memories/{user_id}")
|
resp = await client.delete(
|
||||||
|
f"{self.base_url}/memories/{user_id}",
|
||||||
|
headers=self._headers(),
|
||||||
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json().get("deleted", 0)
|
return resp.json().get("deleted", 0)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -492,7 +505,10 @@ class MemoryClient:
|
|||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
resp = await client.get(f"{self.base_url}/memories/{user_id}")
|
resp = await client.get(
|
||||||
|
f"{self.base_url}/memories/{user_id}",
|
||||||
|
headers=self._headers(),
|
||||||
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json().get("memories", [])
|
return resp.json().get("memories", [])
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -512,19 +528,24 @@ class MemoryClient:
|
|||||||
"chunk_text": chunk_text, "summary": summary,
|
"chunk_text": chunk_text, "summary": summary,
|
||||||
"source_event_id": source_event_id, "original_ts": original_ts,
|
"source_event_id": source_event_id, "original_ts": original_ts,
|
||||||
},
|
},
|
||||||
|
headers=self._headers(),
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Chunk store failed", exc_info=True)
|
logger.warning("Chunk store failed", exc_info=True)
|
||||||
|
|
||||||
async def query_chunks(self, query: str, user_id: str = "", room_id: str = "",
|
async def query_chunks(self, query: str, user_id: str, room_id: str = "",
|
||||||
top_k: int = 5) -> list[dict]:
|
top_k: int = 5) -> list[dict]:
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return []
|
return []
|
||||||
|
if not user_id:
|
||||||
|
logger.error("query_chunks called with empty user_id — returning empty")
|
||||||
|
return []
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
f"{self.base_url}/chunks/query",
|
f"{self.base_url}/chunks/query",
|
||||||
json={"user_id": user_id, "room_id": room_id, "query": query, "top_k": top_k},
|
json={"user_id": user_id, "room_id": room_id, "query": query, "top_k": top_k},
|
||||||
|
headers=self._headers(),
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json().get("results", [])
|
return resp.json().get("results", [])
|
||||||
@@ -961,7 +982,7 @@ class Bot:
|
|||||||
self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY,
|
self.rag = DocumentRAG(PORTAL_URL, BOT_API_KEY,
|
||||||
rag_endpoint=RAG_ENDPOINT, rag_auth_token=RAG_AUTH_TOKEN)
|
rag_endpoint=RAG_ENDPOINT, rag_auth_token=RAG_AUTH_TOKEN)
|
||||||
self.key_manager = RAGKeyManager(self.client, PORTAL_URL, BOT_API_KEY)
|
self.key_manager = RAGKeyManager(self.client, PORTAL_URL, BOT_API_KEY)
|
||||||
self.memory = MemoryClient(MEMORY_SERVICE_URL)
|
self.memory = MemoryClient(MEMORY_SERVICE_URL, token=MEMORY_SERVICE_TOKEN)
|
||||||
self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY)
|
self.atlassian = AtlassianClient(PORTAL_URL, BOT_API_KEY)
|
||||||
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
|
self.llm = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_KEY) if LITELLM_URL else None
|
||||||
self._documents_cache: dict[str, str | None] = {} # matrix_user_id -> connected status
|
self._documents_cache: dict[str, str | None] = {} # matrix_user_id -> connected status
|
||||||
@@ -2092,7 +2113,10 @@ class Bot:
|
|||||||
memory_context = self._format_memories(memories)
|
memory_context = self._format_memories(memories)
|
||||||
|
|
||||||
# Query relevant conversation chunks (RAG over chat history)
|
# Query relevant conversation chunks (RAG over chat history)
|
||||||
chunks = await self.memory.query_chunks(search_query, user_id=sender or "", top_k=5)
|
if sender:
|
||||||
|
chunks = await self.memory.query_chunks(search_query, user_id=sender, room_id=room.room_id, top_k=5)
|
||||||
|
else:
|
||||||
|
chunks = []
|
||||||
chunk_context = self._format_chunks(chunks)
|
chunk_context = self._format_chunks(chunks)
|
||||||
|
|
||||||
# Include room document context (PDFs, Confluence pages, images uploaded to room)
|
# Include room document context (PDFs, Confluence pages, images uploaded to room)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ services:
|
|||||||
- LITELLM_API_KEY
|
- LITELLM_API_KEY
|
||||||
- DEFAULT_MODEL
|
- DEFAULT_MODEL
|
||||||
- MEMORY_SERVICE_URL=http://memory-service:8090
|
- MEMORY_SERVICE_URL=http://memory-service:8090
|
||||||
|
- MEMORY_SERVICE_TOKEN
|
||||||
- PORTAL_URL
|
- PORTAL_URL
|
||||||
- BOT_API_KEY
|
- BOT_API_KEY
|
||||||
volumes:
|
volumes:
|
||||||
@@ -60,6 +61,7 @@ services:
|
|||||||
LITELLM_BASE_URL: ${LITELLM_BASE_URL}
|
LITELLM_BASE_URL: ${LITELLM_BASE_URL}
|
||||||
LITELLM_API_KEY: ${LITELLM_MASTER_KEY}
|
LITELLM_API_KEY: ${LITELLM_MASTER_KEY}
|
||||||
EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
|
EMBED_MODEL: ${EMBED_MODEL:-text-embedding-3-small}
|
||||||
|
MEMORY_SERVICE_TOKEN: ${MEMORY_SERVICE_TOKEN:-}
|
||||||
depends_on:
|
depends_on:
|
||||||
memory-db:
|
memory-db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
import secrets
|
||||||
import time
|
import time
|
||||||
import hashlib
|
import hashlib
|
||||||
import base64
|
import base64
|
||||||
@@ -7,8 +8,8 @@ import base64
|
|||||||
import asyncpg
|
import asyncpg
|
||||||
import httpx
|
import httpx
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
logger = logging.getLogger("memory-service")
|
logger = logging.getLogger("memory-service")
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -23,12 +24,23 @@ EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
|
|||||||
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
|
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
|
||||||
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
|
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
|
||||||
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
|
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
|
||||||
|
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")
|
||||||
|
|
||||||
app = FastAPI(title="Memory Service")
|
app = FastAPI(title="Memory Service")
|
||||||
pool: asyncpg.Pool | None = None
|
pool: asyncpg.Pool | None = None
|
||||||
owner_pool: asyncpg.Pool | None = None
|
owner_pool: asyncpg.Pool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_token(authorization: str | None = Header(None)):
|
||||||
|
"""Bearer token auth — skipped if MEMORY_SERVICE_TOKEN not configured (dev mode)."""
|
||||||
|
if not MEMORY_SERVICE_TOKEN:
|
||||||
|
return
|
||||||
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
|
raise HTTPException(401, "Missing or invalid Authorization header")
|
||||||
|
if not secrets.compare_digest(authorization[7:], MEMORY_SERVICE_TOKEN):
|
||||||
|
raise HTTPException(403, "Invalid token")
|
||||||
|
|
||||||
|
|
||||||
def _derive_user_key(user_id: str) -> bytes:
|
def _derive_user_key(user_id: str) -> bytes:
|
||||||
"""Derive a per-user Fernet key from master key + user_id via HMAC-SHA256."""
|
"""Derive a per-user Fernet key from master key + user_id via HMAC-SHA256."""
|
||||||
if not ENCRYPTION_KEY:
|
if not ENCRYPTION_KEY:
|
||||||
@@ -64,12 +76,26 @@ class StoreRequest(BaseModel):
|
|||||||
fact: str
|
fact: str
|
||||||
source_room: str = ""
|
source_room: str = ""
|
||||||
|
|
||||||
|
@field_validator('user_id')
|
||||||
|
@classmethod
|
||||||
|
def user_id_not_empty(cls, v):
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("user_id is required")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
|
||||||
class QueryRequest(BaseModel):
|
class QueryRequest(BaseModel):
|
||||||
user_id: str
|
user_id: str
|
||||||
query: str
|
query: str
|
||||||
top_k: int = 10
|
top_k: int = 10
|
||||||
|
|
||||||
|
@field_validator('user_id')
|
||||||
|
@classmethod
|
||||||
|
def user_id_not_empty(cls, v):
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("user_id is required")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
|
||||||
class ChunkStoreRequest(BaseModel):
|
class ChunkStoreRequest(BaseModel):
|
||||||
user_id: str
|
user_id: str
|
||||||
@@ -79,13 +105,27 @@ class ChunkStoreRequest(BaseModel):
|
|||||||
source_event_id: str = ""
|
source_event_id: str = ""
|
||||||
original_ts: float = 0.0
|
original_ts: float = 0.0
|
||||||
|
|
||||||
|
@field_validator('user_id')
|
||||||
|
@classmethod
|
||||||
|
def user_id_not_empty(cls, v):
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("user_id is required")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
|
||||||
class ChunkQueryRequest(BaseModel):
|
class ChunkQueryRequest(BaseModel):
|
||||||
user_id: str = ""
|
user_id: str # REQUIRED — no default
|
||||||
room_id: str = ""
|
room_id: str = ""
|
||||||
query: str
|
query: str
|
||||||
top_k: int = 5
|
top_k: int = 5
|
||||||
|
|
||||||
|
@field_validator('user_id')
|
||||||
|
@classmethod
|
||||||
|
def user_id_not_empty(cls, v):
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("user_id is required")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
|
||||||
class ChunkBulkStoreRequest(BaseModel):
|
class ChunkBulkStoreRequest(BaseModel):
|
||||||
chunks: list[ChunkStoreRequest]
|
chunks: list[ChunkStoreRequest]
|
||||||
@@ -166,6 +206,9 @@ async def _init_db():
|
|||||||
await conn.execute("""
|
await conn.execute("""
|
||||||
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
|
CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
|
||||||
""")
|
""")
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
|
||||||
|
""")
|
||||||
finally:
|
finally:
|
||||||
await owner_conn.close()
|
await owner_conn.close()
|
||||||
# Create restricted pool for all request handlers (RLS applies)
|
# Create restricted pool for all request handlers (RLS applies)
|
||||||
@@ -205,7 +248,7 @@ async def health():
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/memories/store")
|
@app.post("/memories/store")
|
||||||
async def store_memory(req: StoreRequest):
|
async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
|
||||||
"""Embed fact, deduplicate by cosine similarity, insert encrypted."""
|
"""Embed fact, deduplicate by cosine similarity, insert encrypted."""
|
||||||
if not req.fact.strip():
|
if not req.fact.strip():
|
||||||
raise HTTPException(400, "Empty fact")
|
raise HTTPException(400, "Empty fact")
|
||||||
@@ -243,7 +286,7 @@ async def store_memory(req: StoreRequest):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/memories/query")
|
@app.post("/memories/query")
|
||||||
async def query_memories(req: QueryRequest):
|
async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
|
||||||
"""Embed query, return top-K similar facts for user."""
|
"""Embed query, return top-K similar facts for user."""
|
||||||
embedding = await _embed(req.query)
|
embedding = await _embed(req.query)
|
||||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
@@ -276,7 +319,7 @@ async def query_memories(req: QueryRequest):
|
|||||||
|
|
||||||
|
|
||||||
@app.delete("/memories/{user_id}")
|
@app.delete("/memories/{user_id}")
|
||||||
async def delete_user_memories(user_id: str):
|
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||||
"""GDPR delete — remove all memories for a user."""
|
"""GDPR delete — remove all memories for a user."""
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
await _set_rls_user(conn, user_id)
|
await _set_rls_user(conn, user_id)
|
||||||
@@ -287,7 +330,7 @@ async def delete_user_memories(user_id: str):
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/memories/{user_id}")
|
@app.get("/memories/{user_id}")
|
||||||
async def list_user_memories(user_id: str):
|
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
|
||||||
"""List all memories for a user (for UI/debug)."""
|
"""List all memories for a user (for UI/debug)."""
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
await _set_rls_user(conn, user_id)
|
await _set_rls_user(conn, user_id)
|
||||||
@@ -314,7 +357,7 @@ async def list_user_memories(user_id: str):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/chunks/store")
|
@app.post("/chunks/store")
|
||||||
async def store_chunk(req: ChunkStoreRequest):
|
async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
|
||||||
"""Store a conversation chunk with its summary embedding, encrypted."""
|
"""Store a conversation chunk with its summary embedding, encrypted."""
|
||||||
if not req.summary.strip():
|
if not req.summary.strip():
|
||||||
raise HTTPException(400, "Empty summary")
|
raise HTTPException(400, "Empty summary")
|
||||||
@@ -342,34 +385,26 @@ async def store_chunk(req: ChunkStoreRequest):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/chunks/query")
|
@app.post("/chunks/query")
|
||||||
async def query_chunks(req: ChunkQueryRequest):
|
async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
|
||||||
"""Semantic search over conversation chunks. Filter by user_id and/or room_id."""
|
"""Semantic search over conversation chunks. Filter by user_id and/or room_id."""
|
||||||
embedding = await _embed(req.query)
|
embedding = await _embed(req.query)
|
||||||
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||||
|
|
||||||
# Build WHERE clause dynamically based on provided filters
|
# Build WHERE clause — user_id is always required
|
||||||
conditions = []
|
conditions = [f"user_id = $2"]
|
||||||
params: list = [vec_literal]
|
params: list = [vec_literal, req.user_id]
|
||||||
idx = 2
|
idx = 3
|
||||||
|
|
||||||
if req.user_id:
|
|
||||||
conditions.append(f"user_id = ${idx}")
|
|
||||||
params.append(req.user_id)
|
|
||||||
idx += 1
|
|
||||||
if req.room_id:
|
if req.room_id:
|
||||||
conditions.append(f"room_id = ${idx}")
|
conditions.append(f"room_id = ${idx}")
|
||||||
params.append(req.room_id)
|
params.append(req.room_id)
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
where = f"WHERE {' AND '.join(conditions)}"
|
||||||
params.append(req.top_k)
|
params.append(req.top_k)
|
||||||
|
|
||||||
# Determine user_id for RLS and decryption
|
|
||||||
rls_user = req.user_id if req.user_id else ""
|
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
if rls_user:
|
await _set_rls_user(conn, req.user_id)
|
||||||
await _set_rls_user(conn, rls_user)
|
|
||||||
|
|
||||||
rows = await conn.fetch(
|
rows = await conn.fetch(
|
||||||
f"""
|
f"""
|
||||||
@@ -399,7 +434,7 @@ async def query_chunks(req: ChunkQueryRequest):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/chunks/bulk-store")
|
@app.post("/chunks/bulk-store")
|
||||||
async def bulk_store_chunks(req: ChunkBulkStoreRequest):
|
async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
|
||||||
"""Batch store conversation chunks. Embeds summaries in batches of 20."""
|
"""Batch store conversation chunks. Embeds summaries in batches of 20."""
|
||||||
if not req.chunks:
|
if not req.chunks:
|
||||||
return {"stored": 0}
|
return {"stored": 0}
|
||||||
@@ -440,7 +475,7 @@ async def bulk_store_chunks(req: ChunkBulkStoreRequest):
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/chunks/{user_id}/count")
|
@app.get("/chunks/{user_id}/count")
|
||||||
async def count_user_chunks(user_id: str):
|
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
|
||||||
"""Count conversation chunks for a user."""
|
"""Count conversation chunks for a user."""
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
await _set_rls_user(conn, user_id)
|
await _set_rls_user(conn, user_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user