fix(core): fix end_user_id reference and add task status tracking
- Fix write_router to use actual_end_user_id instead of end_user_id - Add task status tracking via Redis in scheduler - Expose task_id in memory write response - Fix logging import path in scheduler
This commit is contained in:
@@ -6,7 +6,7 @@ import redis
|
||||
|
||||
from app.core.config import settings
|
||||
from celery_app import celery_app
|
||||
from core.logging_config import get_named_logger
|
||||
from app.core.logging_config import get_named_logger
|
||||
|
||||
logger = get_named_logger("task_scheduler")
|
||||
|
||||
@@ -62,11 +62,34 @@ class RedisTaskScheduler:
|
||||
"params": json.dumps(params),
|
||||
}
|
||||
)
|
||||
self.redis.set(
|
||||
f"task_tracker:{msg_id}",
|
||||
json.dumps({"status": "QUEUED", "task_id": None}),
|
||||
ex=86400
|
||||
)
|
||||
return msg_id
|
||||
except Exception as e:
|
||||
logger.error("Push task exception %s", e, exc_info=True)
|
||||
raise e
|
||||
|
||||
def get_task_status(self, msg_id: str) -> dict:
|
||||
raw = self.redis.get(f"task_tracker:{msg_id}")
|
||||
if raw is None:
|
||||
return {"status": "NOT_FOUND"}
|
||||
|
||||
tracker = json.loads(raw)
|
||||
status = tracker["status"]
|
||||
task_id = tracker.get("task_id")
|
||||
result_content = tracker.get("result") or {}
|
||||
if status == "DISPATCHED" and task_id:
|
||||
result_raw = self.redis.get(f"celery-task-meta-{task_id}")
|
||||
if result_raw:
|
||||
result_data = json.loads(result_raw)
|
||||
status = result_data.get("status", status)
|
||||
result_content = result_data.get("result")
|
||||
|
||||
return {"status": status, "task_id": task_id, "result": result_content}
|
||||
|
||||
def _cleanup_finished(self):
|
||||
pending = self.redis.hgetall(PENDING_HASH)
|
||||
if not pending:
|
||||
@@ -91,6 +114,7 @@ class RedisTaskScheduler:
|
||||
age = now - dispatched_at
|
||||
|
||||
should_cleanup = False
|
||||
result_data = None
|
||||
if raw_result is not None:
|
||||
result_data = json.loads(raw_result)
|
||||
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
|
||||
@@ -104,8 +128,20 @@ class RedisTaskScheduler:
|
||||
)
|
||||
|
||||
if should_cleanup:
|
||||
final_status = result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
|
||||
cleanup_pipe.delete(lock_key)
|
||||
cleanup_pipe.hdel(PENDING_HASH, task_id)
|
||||
tracker_msg_id = meta.get("msg_id")
|
||||
if tracker_msg_id:
|
||||
cleanup_pipe.set(
|
||||
f"task_tracker:{tracker_msg_id}",
|
||||
json.dumps({
|
||||
"status": final_status,
|
||||
"task_id": task_id,
|
||||
"result": result_data.get("result") or {}
|
||||
}),
|
||||
ex=86400,
|
||||
)
|
||||
has_cleanup = True
|
||||
except Exception as e:
|
||||
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
|
||||
@@ -125,9 +161,15 @@ class RedisTaskScheduler:
|
||||
pipe.set(lock_key, task.id, ex=3600)
|
||||
pipe.hset(PENDING_HASH, task.id, json.dumps({
|
||||
"lock_key": lock_key,
|
||||
"dispatched_at": time.time()
|
||||
"dispatched_at": time.time(),
|
||||
"msg_id": msg_id
|
||||
}))
|
||||
pipe.xdel(STREAM_KEY, msg_id)
|
||||
pipe.set(
|
||||
f"task_tracker:{msg_id}",
|
||||
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
|
||||
ex=86400,
|
||||
)
|
||||
pipe.execute()
|
||||
self.dispatched += 1
|
||||
logger.info("Task dispatched: %s", task.id)
|
||||
@@ -183,8 +225,8 @@ class RedisTaskScheduler:
|
||||
if locked or lock_key in deliver_keys:
|
||||
continue
|
||||
|
||||
key = self._dispatch(msg_id, msg_data)
|
||||
if key:
|
||||
dispatched_successfully = self._dispatch(msg_id, msg_data)
|
||||
if dispatched_successfully:
|
||||
deliver_keys.add(lock_key)
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from app.schemas.memory_api_schema import (
|
||||
MemoryWriteSyncResponse,
|
||||
)
|
||||
from app.services.memory_api_service import MemoryAPIService
|
||||
from celery_task_scheduler import scheduler
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||
logger = get_business_logger()
|
||||
@@ -86,7 +87,7 @@ async def write_memory(
|
||||
user_rag_memory_id=payload.user_rag_memory_id,
|
||||
)
|
||||
|
||||
logger.info(f"Memory write task submitted: end_user_id: {payload.end_user_id}")
|
||||
logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}")
|
||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
||||
|
||||
|
||||
@@ -105,8 +106,7 @@ async def get_write_task_status(
|
||||
"""
|
||||
logger.info(f"Write task status check - task_id: {task_id}")
|
||||
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
result = get_task_memory_write_result(task_id)
|
||||
result = scheduler.get_task_status(task_id)
|
||||
|
||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||
|
||||
|
||||
@@ -94,9 +94,9 @@ async def write(
|
||||
# )
|
||||
scheduler.push_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
end_user_id,
|
||||
actual_end_user_id,
|
||||
{
|
||||
"end_user_id": end_user_id,
|
||||
"end_user_id": actual_end_user_id,
|
||||
"message": structured_messages,
|
||||
"config_id": str(actual_config_id),
|
||||
"storage_type": storage_type,
|
||||
|
||||
@@ -112,12 +112,12 @@ class MemoryWriteResponse(BaseModel):
|
||||
"""Response schema for memory write operation.
|
||||
|
||||
Attributes:
|
||||
task_id: Celery task ID for status polling
|
||||
status: Initial task status (PENDING)
|
||||
task_id: task ID for status polling
|
||||
status: Initial task status (QUEUED)
|
||||
end_user_id: End user ID the write was submitted for
|
||||
"""
|
||||
task_id: str = Field(..., description="Celery task ID for polling")
|
||||
status: str = Field(..., description="Task status: PENDING")
|
||||
task_id: str = Field(..., description="task ID for polling")
|
||||
status: str = Field(..., description="Task status: QUEUED")
|
||||
end_user_id: str = Field(..., description="End user ID")
|
||||
|
||||
|
||||
|
||||
@@ -175,7 +175,7 @@ class MemoryAPIService:
|
||||
# storage_type,
|
||||
# user_rag_memory_id or "",
|
||||
# )
|
||||
scheduler.push_task(
|
||||
task_id = scheduler.push_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
end_user_id,
|
||||
{
|
||||
@@ -190,8 +190,8 @@ class MemoryAPIService:
|
||||
logger.info(f"Memory write task submitted, end_user_id={end_user_id}")
|
||||
|
||||
return {
|
||||
# "task_id": task.id,
|
||||
"status": "PENDING",
|
||||
"task_id": task_id,
|
||||
"status": "QUEUED",
|
||||
"end_user_id": end_user_id,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user