Add scheduled messages/reminders system: - New scheduled_messages table in memory-service with CRUD endpoints - schedule_message, list_reminders, cancel_reminder tools for the bot - Background scheduler loop (30s) sends due reminders automatically - Supports one-time, daily, weekly, weekdays, monthly repeat patterns Make article URL handling non-blocking: - Show 3 options (discuss, text summary, audio) instead of forcing audio wizard - Default to passing article context to AI if user just keeps chatting - New AWAITING_LANGUAGE state for cleaner audio flow FSM Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
766 lines
26 KiB
Python
766 lines
26 KiB
Python
import asyncio
import base64
import hashlib
import logging
import os
import secrets
import time

import asyncpg
import httpx
from cryptography.fernet import Fernet
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, field_validator
|
logger = logging.getLogger("memory-service")
logging.basicConfig(level=logging.INFO)

# Restricted-role DSN used by request handlers (RLS policies apply).
DB_DSN = os.environ.get("DATABASE_URL", "postgresql://memory:memory@memory-db:5432/memories")
# Owner-role DSN for DDL and admin queries; default interpolates the owner password.
OWNER_DSN = os.environ.get("OWNER_DATABASE_URL", "postgresql://memory:{pw}@memory-db:5432/memories".format(
    pw=os.environ.get("MEMORY_DB_OWNER_PASSWORD", "memory")
))
# LiteLLM proxy used for the /embeddings calls; empty means unconfigured.
LITELLM_URL = os.environ.get("LITELLM_BASE_URL", "")
LITELLM_KEY = os.environ.get("LITELLM_API_KEY", "not-needed")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "text-embedding-3-small")
# Embedding dimensionality — must match the vector(...) columns created in _init_db.
EMBED_DIMS = int(os.environ.get("EMBED_DIMS", "1536"))
# Cosine-similarity threshold above which a new fact is considered a duplicate.
DEDUP_THRESHOLD = float(os.environ.get("DEDUP_THRESHOLD", "0.92"))
# Master key for per-user Fernet encryption; empty string disables encryption.
ENCRYPTION_KEY = os.environ.get("MEMORY_ENCRYPTION_KEY", "")
# Bearer token required on every endpoint; empty string disables auth (dev mode).
MEMORY_SERVICE_TOKEN = os.environ.get("MEMORY_SERVICE_TOKEN", "")

app = FastAPI(title="Memory Service")
# Restricted connection pool (RLS applies) — created in _init_db at startup.
pool: asyncpg.Pool | None = None
# Owner connection pool (bypasses RLS) — created in _init_db at startup.
owner_pool: asyncpg.Pool | None = None
# Set False when a DB operation fails; _ensure_pool reconnects when it is False.
_pool_healthy = True
|
|
|
|
|
async def verify_token(authorization: str | None = Header(None)):
    """Validate the Bearer token; no-op when MEMORY_SERVICE_TOKEN is unset (dev mode)."""
    if not MEMORY_SERVICE_TOKEN:
        return
    has_bearer = bool(authorization) and authorization.startswith("Bearer ")
    if not has_bearer:
        raise HTTPException(401, "Missing or invalid Authorization header")
    supplied = authorization[len("Bearer "):]
    # Constant-time comparison to avoid leaking the token via timing.
    if not secrets.compare_digest(supplied, MEMORY_SERVICE_TOKEN):
        raise HTTPException(403, "Invalid token")
|
|
|
|
|
|
def _derive_user_key(user_id: str) -> bytes:
    """Derive a per-user Fernet key: PBKDF2-HMAC-SHA256(master key, salt=user_id).

    NOTE(review): iterations=1 looks deliberate — the master key is presumably
    high-entropy (not a password), and raising the iteration count would make
    all previously written ciphertexts undecryptable. Confirm before changing.
    """
    if not ENCRYPTION_KEY:
        raise RuntimeError("MEMORY_ENCRYPTION_KEY not set")
    raw_key = hashlib.pbkdf2_hmac(
        "sha256", ENCRYPTION_KEY.encode(), user_id.encode(), iterations=1
    )
    # Fernet requires a urlsafe-base64-encoded 32-byte key.
    return base64.urlsafe_b64encode(raw_key)
|
|
|
|
|
|
def _encrypt(text: str, user_id: str) -> str:
    """Encrypt *text* under the user's derived Fernet key.

    Returns the base64 ciphertext, or the input unchanged when encryption
    is disabled (no ENCRYPTION_KEY configured).
    """
    if not ENCRYPTION_KEY:
        return text
    cipher = Fernet(_derive_user_key(user_id))
    token = cipher.encrypt(text.encode())
    return token.decode()
|
|
|
|
|
|
def _decrypt(ciphertext: str, user_id: str) -> str:
    """Decrypt *ciphertext* with the user's derived Fernet key.

    Falls back to returning the input unchanged when encryption is disabled
    or decryption fails (rows written before encryption was enabled).
    """
    if not ENCRYPTION_KEY:
        return ciphertext
    cipher = Fernet(_derive_user_key(user_id))
    try:
        return cipher.decrypt(ciphertext.encode()).decode()
    except Exception:
        # Plaintext fallback for not-yet-migrated rows
        return ciphertext
|
|
|
|
|
|
class ScheduleRequest(BaseModel):
    """Payload for creating a scheduled message / reminder."""

    user_id: str
    room_id: str
    message_text: str
    scheduled_at: float  # Unix timestamp
    repeat_pattern: str = "once"  # once | daily | weekly | weekdays | monthly

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, value):
        # Reject empty/whitespace-only ids; normalize surrounding whitespace.
        cleaned = value.strip() if value else ""
        if not cleaned:
            raise ValueError("user_id is required")
        return cleaned

    @field_validator('repeat_pattern')
    @classmethod
    def valid_pattern(cls, value):
        # Only the patterns the scheduler knows how to advance are accepted.
        allowed = {"once", "daily", "weekly", "weekdays", "monthly"}
        if value in allowed:
            return value
        raise ValueError(f"repeat_pattern must be one of {allowed}")
|
|
|
|
|
|
class ScheduleCancelRequest(BaseModel):
    """Payload for cancelling a reminder by id; user_id must match the owner."""

    id: int
    user_id: str

    # Consistency fix: every other request model in this file validates
    # user_id; this one previously accepted empty strings, which can never
    # match a real owner.
    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError("user_id is required")
        return v.strip()
|
|
|
|
|
|
class StoreRequest(BaseModel):
    """Payload for storing a single long-term memory fact."""

    user_id: str
    fact: str
    source_room: str = ""

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, value):
        # Reject empty/whitespace-only ids; normalize surrounding whitespace.
        cleaned = value.strip() if value else ""
        if not cleaned:
            raise ValueError("user_id is required")
        return cleaned
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Payload for semantic search over a user's stored facts."""

    user_id: str
    query: str
    top_k: int = 10  # number of nearest facts to return

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, value):
        # Reject empty/whitespace-only ids; normalize surrounding whitespace.
        cleaned = value.strip() if value else ""
        if not cleaned:
            raise ValueError("user_id is required")
        return cleaned
|
|
|
|
|
|
class ChunkStoreRequest(BaseModel):
    """Payload for storing one conversation chunk with its summary."""

    user_id: str
    room_id: str
    chunk_text: str
    summary: str
    source_event_id: str = ""
    original_ts: float = 0.0  # 0.0 means "use current time at store"

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, value):
        # Reject empty/whitespace-only ids; normalize surrounding whitespace.
        cleaned = value.strip() if value else ""
        if not cleaned:
            raise ValueError("user_id is required")
        return cleaned
|
|
|
|
|
|
class ChunkQueryRequest(BaseModel):
    """Payload for semantic search over conversation chunks."""

    user_id: str  # REQUIRED — no default
    room_id: str = ""  # optional extra filter
    query: str
    top_k: int = 5

    @field_validator('user_id')
    @classmethod
    def user_id_not_empty(cls, value):
        # Reject empty/whitespace-only ids; normalize surrounding whitespace.
        cleaned = value.strip() if value else ""
        if not cleaned:
            raise ValueError("user_id is required")
        return cleaned
|
|
|
|
|
|
class ChunkBulkStoreRequest(BaseModel):
    # Batch payload for /chunks/bulk-store; each element is validated as a
    # ChunkStoreRequest (including its non-empty user_id check).
    chunks: list[ChunkStoreRequest]
|
|
|
|
|
|
async def _embed(text: str) -> list[float]:
    """Return the embedding vector for *text* via the LiteLLM /embeddings API."""
    payload = {"model": EMBED_MODEL, "input": text}
    headers = {"Authorization": f"Bearer {LITELLM_KEY}"}
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings", json=payload, headers=headers
        )
        resp.raise_for_status()
        body = resp.json()
    return body["data"][0]["embedding"]
|
|
|
|
|
|
async def _embed_batch(texts: list[str]) -> list[list[float]]:
    """Embed several texts in one LiteLLM call; results are in input order."""
    payload = {"model": EMBED_MODEL, "input": texts}
    headers = {"Authorization": f"Bearer {LITELLM_KEY}"}
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings", json=payload, headers=headers
        )
        resp.raise_for_status()
        items = resp.json()["data"]
    # The API may return items out of order; "index" restores input order.
    items.sort(key=lambda item: item["index"])
    return [item["embedding"] for item in items]
|
|
|
|
|
|
async def _set_rls_user(conn, user_id: str):
    """Set the RLS session variable for the current connection."""
    # set_config(..., false) is session-scoped, so the value stays in effect
    # for the rest of this pooled connection's checkout; the row-level
    # security policies are expected to read app.current_user_id.
    await conn.execute("SELECT set_config('app.current_user_id', $1, false)", user_id)
|
|
|
|
|
|
# Serializes pool reconnection; created lazily so it binds the running loop.
_reconnect_lock: asyncio.Lock | None = None


async def _ensure_pool():
    """Recreate the connection pools if they were lost.

    Fix: reconnection is now serialized with an asyncio.Lock. Previously two
    concurrent requests could both observe an unhealthy pool and race through
    close/create, with the second closing the pool the first had just created.
    """
    global pool, owner_pool, _pool_healthy, _reconnect_lock
    if pool and _pool_healthy:
        return
    if _reconnect_lock is None:
        _reconnect_lock = asyncio.Lock()
    async with _reconnect_lock:
        # Re-check: another task may have reconnected while we waited.
        if pool and _pool_healthy:
            return
        logger.warning("Reconnecting asyncpg pools (healthy=%s, pool=%s)", _pool_healthy, pool is not None)
        try:
            for stale in (pool, owner_pool):
                if stale:
                    try:
                        await stale.close()
                    except Exception:
                        pass  # best-effort: the pool may already be dead
            pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
            owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
            _pool_healthy = True
            logger.info("asyncpg pools reconnected successfully")
        except Exception:
            _pool_healthy = False
            logger.exception("Failed to reconnect asyncpg pools")
            raise
|
|
|
|
|
|
async def _init_db():
    """Create the pgvector extension, tables, and indexes, then build app pools.

    DDL runs over a one-off owner connection (the restricted request-handler
    role may lack CREATE rights). Afterwards two pools are created:
      * ``pool``       — restricted role for request handlers (RLS applies)
      * ``owner_pool`` — owner role for admin/scheduler queries (bypasses RLS)

    Fix: all three module globals are declared once at the top (``owner_pool``
    was previously declared mid-function) and placeholder-less f-strings are
    plain strings now.
    """
    global pool, owner_pool, _pool_healthy
    owner_conn = await asyncpg.connect(OWNER_DSN)
    try:
        await owner_conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        # Long-term distilled facts, one embedding per fact.
        await owner_conn.execute(f"""
            CREATE TABLE IF NOT EXISTS memories (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                fact TEXT NOT NULL,
                source_room TEXT DEFAULT '',
                created_at DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id)
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memories_embedding
            ON memories USING ivfflat (embedding vector_cosine_ops)
            WITH (lists = 100)
        """)
        # Conversation chunks table for RAG over chat history
        await owner_conn.execute(f"""
            CREATE TABLE IF NOT EXISTS conversation_chunks (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                room_id TEXT NOT NULL,
                chunk_text TEXT NOT NULL,
                summary TEXT NOT NULL,
                source_event_id TEXT DEFAULT '',
                original_ts DOUBLE PRECISION NOT NULL,
                embedding vector({EMBED_DIMS}) NOT NULL
            )
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user_id ON conversation_chunks (user_id)
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_room_id ON conversation_chunks (room_id)
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user_room ON conversation_chunks (user_id, room_id)
        """)
        # Scheduled messages table for reminders
        await owner_conn.execute("""
            CREATE TABLE IF NOT EXISTS scheduled_messages (
                id BIGSERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                room_id TEXT NOT NULL,
                message_text TEXT NOT NULL,
                scheduled_at DOUBLE PRECISION NOT NULL,
                created_at DOUBLE PRECISION NOT NULL,
                status TEXT DEFAULT 'pending',
                repeat_pattern TEXT DEFAULT 'once',
                repeat_interval_seconds INTEGER DEFAULT 0,
                last_sent_at DOUBLE PRECISION DEFAULT 0
            )
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_scheduled_user_id ON scheduled_messages (user_id)
        """)
        await owner_conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_scheduled_status ON scheduled_messages (status, scheduled_at)
        """)
    finally:
        await owner_conn.close()
    # Create restricted pool for all request handlers (RLS applies)
    pool = await asyncpg.create_pool(DB_DSN, min_size=2, max_size=10)
    # Owner pool for admin queries (bypasses RLS) — kept tiny on purpose
    owner_pool = await asyncpg.create_pool(OWNER_DSN, min_size=1, max_size=2)
    _pool_healthy = True
    logger.info("Database initialized (dims=%d, encryption=%s)", EMBED_DIMS, "ON" if ENCRYPTION_KEY else "OFF")
|
|
|
|
|
|
@app.on_event("startup")
async def startup():
    """Initialize the database schema and connection pools before serving."""
    # NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
    # lifespan handlers; left as-is to avoid restructuring the app wiring.
    await _init_db()
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown():
    """Close both connection pools on application shutdown."""
    for active_pool in (pool, owner_pool):
        if active_pool:
            await active_pool.close()
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Health probe: table counts plus encryption status, or the failure reason."""
    global _pool_healthy
    try:
        await _ensure_pool()
        # Owner pool so the counts are global (not limited by RLS).
        async with owner_pool.acquire() as conn:
            memories_total = await conn.fetchval("SELECT count(*) FROM memories")
            chunks_total = await conn.fetchval("SELECT count(*) FROM conversation_chunks")
            pending_total = await conn.fetchval("SELECT count(*) FROM scheduled_messages WHERE status = 'pending'")
    except Exception as e:
        _pool_healthy = False
        logger.error("Health check failed: %s", e)
        return {"status": "unhealthy", "error": str(e)}
    return {
        "status": "ok",
        "total_memories": memories_total,
        "total_chunks": chunks_total,
        "pending_reminders": pending_total,
        "encryption": "on" if ENCRYPTION_KEY else "off",
    }
|
|
|
|
|
|
@app.post("/memories/store")
async def store_memory(req: StoreRequest, _: None = Depends(verify_token)):
    """Embed fact, deduplicate by cosine similarity, insert encrypted."""
    if not req.fact.strip():
        raise HTTPException(400, "Empty fact")

    # Embed the plaintext fact; the vector is stored in the clear alongside
    # the encrypted text so similarity search works without decryption.
    embedding = await _embed(req.fact)
    # pgvector accepts a '[v1,v2,...]' string literal cast with ::vector.
    vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"

    await _ensure_pool()
    async with pool.acquire() as conn:
        # RLS session variable must be set before any table access.
        await _set_rls_user(conn, req.user_id)

        # Check for duplicates (cosine similarity > threshold)
        dup = await conn.fetchval(
            """
            SELECT id FROM memories
            WHERE user_id = $1
            AND 1 - (embedding <=> $2::vector) > $3
            LIMIT 1
            """,
            req.user_id, vec_literal, DEDUP_THRESHOLD,
        )
        if dup:
            logger.info("Duplicate memory for %s (similar to id=%d), skipping", req.user_id, dup)
            return {"stored": False, "reason": "duplicate"}

        # Fact text is encrypted per-user; only the embedding stays readable.
        encrypted_fact = _encrypt(req.fact.strip(), req.user_id)
        await conn.execute(
            """
            INSERT INTO memories (user_id, fact, source_room, created_at, embedding)
            VALUES ($1, $2, $3, $4, $5::vector)
            """,
            req.user_id, encrypted_fact, req.source_room, time.time(), vec_literal,
        )
    logger.info("Stored memory for %s: %s", req.user_id, req.fact[:60])
    return {"stored": True}
|
|
|
|
|
|
@app.post("/memories/query")
async def query_memories(req: QueryRequest, _: None = Depends(verify_token)):
    """Embed the query text and return the user's top-K most similar facts."""
    query_vec = await _embed(req.query)
    vec_literal = "[" + ",".join(str(component) for component in query_vec) + "]"

    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at,
                   1 - (embedding <=> $1::vector) AS similarity
            FROM memories
            WHERE user_id = $2
            ORDER BY embedding <=> $1::vector
            LIMIT $3
            """,
            vec_literal, req.user_id, req.top_k,
        )

    return {
        "results": [
            {
                "fact": _decrypt(row["fact"], req.user_id),
                "source_room": row["source_room"],
                "created_at": row["created_at"],
                "similarity": float(row["similarity"]),
            }
            for row in rows
        ]
    }
|
|
|
|
|
|
@app.delete("/memories/{user_id}")
async def delete_user_memories(user_id: str, _: None = Depends(verify_token)):
    """GDPR delete — remove every memory row belonging to *user_id*."""
    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        status_tag = await conn.execute("DELETE FROM memories WHERE user_id = $1", user_id)
    # asyncpg returns a command tag like "DELETE 3"; the last token is the count.
    deleted = int(status_tag.split()[-1])
    logger.info("Deleted %d memories for %s", deleted, user_id)
    return {"deleted": deleted}
|
|
|
|
|
|
@app.get("/memories/{user_id}")
async def list_user_memories(user_id: str, _: None = Depends(verify_token)):
    """List every memory for a user, newest first (for UI/debug)."""
    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        rows = await conn.fetch(
            """
            SELECT fact, source_room, created_at
            FROM memories
            WHERE user_id = $1
            ORDER BY created_at DESC
            """,
            user_id,
        )

    memories = [
        {
            "fact": _decrypt(row["fact"], user_id),
            "source_room": row["source_room"],
            "created_at": row["created_at"],
        }
        for row in rows
    ]
    return {"user_id": user_id, "count": len(memories), "memories": memories}
|
|
|
|
|
|
# --- Conversation Chunks ---
|
|
|
|
|
|
@app.post("/chunks/store")
async def store_chunk(req: ChunkStoreRequest, _: None = Depends(verify_token)):
    """Store one conversation chunk; summary is embedded, both texts encrypted."""
    if not req.summary.strip():
        raise HTTPException(400, "Empty summary")

    summary_vec = await _embed(req.summary)
    vec_literal = "[" + ",".join(str(component) for component in summary_vec) + "]"
    ts = req.original_ts or time.time()
    enc_text = _encrypt(req.chunk_text, req.user_id)
    enc_summary = _encrypt(req.summary, req.user_id)

    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        await conn.execute(
            """
            INSERT INTO conversation_chunks
            (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
            """,
            req.user_id, req.room_id, enc_text, enc_summary,
            req.source_event_id, ts, vec_literal,
        )
    logger.info("Stored chunk for %s in %s: %s", req.user_id, req.room_id, req.summary[:60])
    return {"stored": True}
|
|
|
|
|
|
@app.post("/chunks/query")
async def query_chunks(req: ChunkQueryRequest, _: None = Depends(verify_token)):
    """Semantic search over conversation chunks, filtered by user_id and optional room_id.

    Fixes: drops a spurious f-prefix on a placeholder-less string and derives
    placeholder indexes from len(params) instead of a hand-maintained counter.
    All values still go through bind parameters; only the column-name clauses
    are interpolated.
    """
    embedding = await _embed(req.query)
    vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"

    # $1 = query vector, $2 = user_id; room filter and LIMIT appended as needed.
    conditions = ["user_id = $2"]
    params: list = [vec_literal, req.user_id]
    if req.room_id:
        params.append(req.room_id)
        conditions.append(f"room_id = ${len(params)}")
    params.append(req.top_k)
    limit_idx = len(params)

    where = f"WHERE {' AND '.join(conditions)}"

    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, req.user_id)
        rows = await conn.fetch(
            f"""
            SELECT chunk_text, summary, room_id, user_id, original_ts, source_event_id,
                   1 - (embedding <=> $1::vector) AS similarity
            FROM conversation_chunks
            {where}
            ORDER BY embedding <=> $1::vector
            LIMIT ${limit_idx}
            """,
            *params,
        )

    results = [
        {
            "chunk_text": _decrypt(r["chunk_text"], r["user_id"]),
            "summary": _decrypt(r["summary"], r["user_id"]),
            "room_id": r["room_id"],
            "user_id": r["user_id"],
            "original_ts": r["original_ts"],
            "source_event_id": r["source_event_id"],
            "similarity": float(r["similarity"]),
        }
        for r in rows
    ]
    return {"results": results}
|
|
|
|
|
|
@app.post("/chunks/bulk-store")
async def bulk_store_chunks(req: ChunkBulkStoreRequest, _: None = Depends(verify_token)):
    """Batch store conversation chunks; summaries are embedded in batches of 20.

    Fixes: chunks with empty summaries are now skipped with a warning — the
    single-store endpoint rejects them with 400, but this path previously
    embedded "" — and the pool is ensured once up front instead of per batch.
    A failed embed batch is still skipped (best-effort semantics preserved).
    """
    if not req.chunks:
        return {"stored": 0}

    usable = [c for c in req.chunks if c.summary.strip()]
    skipped = len(req.chunks) - len(usable)
    if skipped:
        logger.warning("Skipping %d chunks with empty summaries", skipped)

    stored = 0
    batch_size = 20
    await _ensure_pool()

    for i in range(0, len(usable), batch_size):
        batch = usable[i:i + batch_size]
        summaries = [c.summary.strip() for c in batch]

        try:
            embeddings = await _embed_batch(summaries)
        except Exception:
            logger.error("Batch embed failed for chunks %d-%d", i, i + len(batch), exc_info=True)
            continue

        async with pool.acquire() as conn:
            for chunk, embedding in zip(batch, embeddings):
                await _set_rls_user(conn, chunk.user_id)
                vec_literal = "[" + ",".join(str(v) for v in embedding) + "]"
                ts = chunk.original_ts or time.time()
                encrypted_text = _encrypt(chunk.chunk_text, chunk.user_id)
                encrypted_summary = _encrypt(chunk.summary, chunk.user_id)
                await conn.execute(
                    """
                    INSERT INTO conversation_chunks
                    (user_id, room_id, chunk_text, summary, source_event_id, original_ts, embedding)
                    VALUES ($1, $2, $3, $4, $5, $6, $7::vector)
                    """,
                    chunk.user_id, chunk.room_id, encrypted_text, encrypted_summary,
                    chunk.source_event_id, ts, vec_literal,
                )
                stored += 1

    logger.info("Bulk stored %d chunks", stored)
    return {"stored": stored}
|
|
|
|
|
|
@app.get("/chunks/{user_id}/count")
async def count_user_chunks(user_id: str, _: None = Depends(verify_token)):
    """Return how many conversation chunks are stored for *user_id*."""
    await _ensure_pool()
    async with pool.acquire() as conn:
        await _set_rls_user(conn, user_id)
        total = await conn.fetchval(
            "SELECT count(*) FROM conversation_chunks WHERE user_id = $1", user_id,
        )
    return {"user_id": user_id, "count": total}
|
|
|
|
|
|
# --- Scheduled Messages ---
|
|
|
|
import calendar
|
|
import datetime
|
|
|
|
|
|
def _compute_repeat_interval(pattern: str) -> int:
|
|
"""Compute repeat_interval_seconds from pattern name."""
|
|
return {
|
|
"once": 0,
|
|
"daily": 86400,
|
|
"weekly": 604800,
|
|
"weekdays": 86400, # special handling in mark-sent
|
|
"monthly": 0, # special handling in mark-sent
|
|
}.get(pattern, 0)
|
|
|
|
|
|
def _next_scheduled_at(current_ts: float, pattern: str) -> float:
|
|
"""Compute the next scheduled_at timestamp for recurring patterns."""
|
|
dt = datetime.datetime.fromtimestamp(current_ts, tz=datetime.timezone.utc)
|
|
|
|
if pattern == "daily":
|
|
return current_ts + 86400.0
|
|
elif pattern == "weekly":
|
|
return current_ts + 604800.0
|
|
elif pattern == "weekdays":
|
|
next_dt = dt + datetime.timedelta(days=1)
|
|
while next_dt.weekday() >= 5: # Skip Sat(5), Sun(6)
|
|
next_dt += datetime.timedelta(days=1)
|
|
return next_dt.timestamp()
|
|
elif pattern == "monthly":
|
|
month = dt.month + 1
|
|
year = dt.year + (month - 1) // 12
|
|
month = (month - 1) % 12 + 1
|
|
day = min(dt.day, calendar.monthrange(year, month)[1])
|
|
return dt.replace(year=year, month=month, day=day).timestamp()
|
|
return current_ts
|
|
|
|
|
|
# Cap on concurrently-pending reminders per user (enforced in create_scheduled).
MAX_REMINDERS_PER_USER = 50
|
|
|
|
|
|
@app.post("/scheduled/create")
async def create_scheduled(req: ScheduleRequest, _: None = Depends(verify_token)):
    """Create a new scheduled message/reminder for a user."""
    now = time.time()
    if req.scheduled_at <= now:
        raise HTTPException(400, "scheduled_at must be in the future")

    await _ensure_pool()
    # Owner pool: the scheduler operates across users, bypassing RLS.
    async with owner_pool.acquire() as conn:
        # Enforce the per-user cap on pending reminders.
        pending = await conn.fetchval(
            "SELECT count(*) FROM scheduled_messages WHERE user_id = $1 AND status = 'pending'",
            req.user_id,
        )
        if pending >= MAX_REMINDERS_PER_USER:
            raise HTTPException(400, f"Maximum {MAX_REMINDERS_PER_USER} active reminders per user")

        msg_text = req.message_text[:2000]  # Truncate long messages
        interval = _compute_repeat_interval(req.repeat_pattern)
        new_id = await conn.fetchval(
            """
            INSERT INTO scheduled_messages
            (user_id, room_id, message_text, scheduled_at, created_at, status, repeat_pattern, repeat_interval_seconds)
            VALUES ($1, $2, $3, $4, $5, 'pending', $6, $7)
            RETURNING id
            """,
            req.user_id, req.room_id, msg_text, req.scheduled_at, now,
            req.repeat_pattern, interval,
        )
    logger.info("Created reminder #%d for %s at %.0f (%s)", new_id, req.user_id, req.scheduled_at, req.repeat_pattern)
    return {"id": new_id, "created": True}
|
|
|
|
|
|
@app.get("/scheduled/{user_id}")
async def list_scheduled(user_id: str, _: None = Depends(verify_token)):
    """List all pending reminders for a user, soonest first."""
    await _ensure_pool()
    async with owner_pool.acquire() as conn:
        rows = await conn.fetch(
            """
            SELECT id, message_text, scheduled_at, repeat_pattern, status
            FROM scheduled_messages
            WHERE user_id = $1 AND status = 'pending'
            ORDER BY scheduled_at
            """,
            user_id,
        )
    # Each asyncpg Record carries exactly the selected columns.
    return {"user_id": user_id, "reminders": [dict(row) for row in rows]}
|
|
|
|
|
|
@app.delete("/scheduled/{user_id}/{reminder_id}")
async def cancel_scheduled(user_id: str, reminder_id: int, _: None = Depends(verify_token)):
    """Cancel a pending reminder; only its owner may cancel it."""
    await _ensure_pool()
    async with owner_pool.acquire() as conn:
        status_tag = await conn.execute(
            """
            UPDATE scheduled_messages SET status = 'cancelled'
            WHERE id = $1 AND user_id = $2 AND status = 'pending'
            """,
            reminder_id, user_id,
        )
    # Command tag is "UPDATE <n>"; n == 0 means no matching pending row.
    if int(status_tag.split()[-1]) == 0:
        raise HTTPException(404, "Reminder not found or already cancelled")
    logger.info("Cancelled reminder #%d for %s", reminder_id, user_id)
    return {"cancelled": True, "id": reminder_id}
|
|
|
|
|
|
@app.post("/scheduled/due")
async def get_due_messages(_: None = Depends(verify_token)):
    """Return up to 100 pending messages whose scheduled_at has passed."""
    now = time.time()
    await _ensure_pool()
    async with owner_pool.acquire() as conn:
        rows = await conn.fetch(
            """
            SELECT id, user_id, room_id, message_text, scheduled_at, repeat_pattern
            FROM scheduled_messages
            WHERE scheduled_at <= $1 AND status = 'pending'
            ORDER BY scheduled_at
            LIMIT 100
            """,
            now,
        )
    # Each asyncpg Record carries exactly the selected columns.
    return {"due": [dict(row) for row in rows]}
|
|
|
|
|
|
@app.post("/scheduled/{reminder_id}/mark-sent")
async def mark_sent(reminder_id: int, _: None = Depends(verify_token)):
    """Mark a reminder as sent; for recurring reminders, advance scheduled_at.

    Fix: the next occurrence is now advanced repeatedly until it lands in the
    future. Previously a single step was taken from the stored scheduled_at,
    so after service downtime a recurring reminder could get a next time still
    in the past and re-fire on every scheduler tick until it caught up.
    """
    now = time.time()
    await _ensure_pool()
    async with owner_pool.acquire() as conn:
        row = await conn.fetchrow(
            "SELECT repeat_pattern, scheduled_at FROM scheduled_messages WHERE id = $1",
            reminder_id,
        )
        if not row:
            raise HTTPException(404, "Reminder not found")

        pattern = row["repeat_pattern"]
        if pattern == "once":
            await conn.execute(
                "UPDATE scheduled_messages SET status = 'sent', last_sent_at = $1 WHERE id = $2",
                now, reminder_id,
            )
        else:
            next_at = _next_scheduled_at(row["scheduled_at"], pattern)
            # Catch-up: skip occurrences missed while the service was down.
            while next_at <= now:
                advanced = _next_scheduled_at(next_at, pattern)
                if advanced <= next_at:
                    # Defensive: an unknown pattern makes no progress; bail out
                    # rather than spin forever.
                    break
                next_at = advanced
            await conn.execute(
                """
                UPDATE scheduled_messages
                SET scheduled_at = $1, last_sent_at = $2
                WHERE id = $3
                """,
                next_at, now, reminder_id,
            )
    logger.info("Marked reminder #%d as sent (pattern=%s)", reminder_id, pattern)
    return {"marked": True}
|