fix(core): fix end_user_id reference and add task status tracking

- Fix write_router to use actual_end_user_id instead of end_user_id
- Add task status tracking via Redis in scheduler
- Expose task_id in memory write response
- Fix logging import path in scheduler
This commit is contained in:
Eternity
2026-04-22 17:46:38 +08:00
parent c5ae82c3c2
commit f93ec8d609
5 changed files with 58 additions and 16 deletions

View File

@@ -6,7 +6,7 @@ import redis
from app.core.config import settings
from celery_app import celery_app
from core.logging_config import get_named_logger
from app.core.logging_config import get_named_logger
logger = get_named_logger("task_scheduler")
@@ -62,11 +62,34 @@ class RedisTaskScheduler:
"params": json.dumps(params),
}
)
self.redis.set(
f"task_tracker:{msg_id}",
json.dumps({"status": "QUEUED", "task_id": None}),
ex=86400
)
return msg_id
except Exception as e:
logger.error("Push task exception %s", e, exc_info=True)
raise e
def get_task_status(self, msg_id: str) -> dict:
raw = self.redis.get(f"task_tracker:{msg_id}")
if raw is None:
return {"status": "NOT_FOUND"}
tracker = json.loads(raw)
status = tracker["status"]
task_id = tracker.get("task_id")
result_content = tracker.get("result") or {}
if status == "DISPATCHED" and task_id:
result_raw = self.redis.get(f"celery-task-meta-{task_id}")
if result_raw:
result_data = json.loads(result_raw)
status = result_data.get("status", status)
result_content = result_data.get("result")
return {"status": status, "task_id": task_id, "result": result_content}
def _cleanup_finished(self):
pending = self.redis.hgetall(PENDING_HASH)
if not pending:
@@ -91,6 +114,7 @@ class RedisTaskScheduler:
age = now - dispatched_at
should_cleanup = False
result_data = None
if raw_result is not None:
result_data = json.loads(raw_result)
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
@@ -104,8 +128,20 @@ class RedisTaskScheduler:
)
if should_cleanup:
final_status = result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
cleanup_pipe.delete(lock_key)
cleanup_pipe.hdel(PENDING_HASH, task_id)
tracker_msg_id = meta.get("msg_id")
if tracker_msg_id:
cleanup_pipe.set(
f"task_tracker:{tracker_msg_id}",
json.dumps({
"status": final_status,
"task_id": task_id,
"result": result_data.get("result") or {}
}),
ex=86400,
)
has_cleanup = True
except Exception as e:
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
@@ -125,9 +161,15 @@ class RedisTaskScheduler:
pipe.set(lock_key, task.id, ex=3600)
pipe.hset(PENDING_HASH, task.id, json.dumps({
"lock_key": lock_key,
"dispatched_at": time.time()
"dispatched_at": time.time(),
"msg_id": msg_id
}))
pipe.xdel(STREAM_KEY, msg_id)
pipe.set(
f"task_tracker:{msg_id}",
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
ex=86400,
)
pipe.execute()
self.dispatched += 1
logger.info("Task dispatched: %s", task.id)
@@ -183,8 +225,8 @@ class RedisTaskScheduler:
if locked or lock_key in deliver_keys:
continue
key = self._dispatch(msg_id, msg_data)
if key:
dispatched_successfully = self._dispatch(msg_id, msg_data)
if dispatched_successfully:
deliver_keys.add(lock_key)
time.sleep(0.1)

View File

@@ -18,6 +18,7 @@ from app.schemas.memory_api_schema import (
MemoryWriteSyncResponse,
)
from app.services.memory_api_service import MemoryAPIService
from celery_task_scheduler import scheduler
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
logger = get_business_logger()
@@ -86,7 +87,7 @@ async def write_memory(
user_rag_memory_id=payload.user_rag_memory_id,
)
logger.info(f"Memory write task submitted: end_user_id: {payload.end_user_id}")
logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}")
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
@@ -105,8 +106,7 @@ async def get_write_task_status(
"""
logger.info(f"Write task status check - task_id: {task_id}")
from app.services.task_service import get_task_memory_write_result
result = get_task_memory_write_result(task_id)
result = scheduler.get_task_status(task_id)
return success(data=_sanitize_task_result(result), msg="Task status retrieved")

View File

@@ -94,9 +94,9 @@ async def write(
# )
scheduler.push_task(
"app.core.memory.agent.write_message",
end_user_id,
actual_end_user_id,
{
"end_user_id": end_user_id,
"end_user_id": actual_end_user_id,
"message": structured_messages,
"config_id": str(actual_config_id),
"storage_type": storage_type,

View File

@@ -112,12 +112,12 @@ class MemoryWriteResponse(BaseModel):
"""Response schema for memory write operation.
Attributes:
task_id: Celery task ID for status polling
status: Initial task status (PENDING)
task_id: task ID for status polling
status: Initial task status (QUEUED)
end_user_id: End user ID the write was submitted for
"""
task_id: str = Field(..., description="Celery task ID for polling")
status: str = Field(..., description="Task status: PENDING")
task_id: str = Field(..., description="task ID for polling")
status: str = Field(..., description="Task status: QUEUED")
end_user_id: str = Field(..., description="End user ID")

View File

@@ -175,7 +175,7 @@ class MemoryAPIService:
# storage_type,
# user_rag_memory_id or "",
# )
scheduler.push_task(
task_id = scheduler.push_task(
"app.core.memory.agent.write_message",
end_user_id,
{
@@ -190,8 +190,8 @@ class MemoryAPIService:
logger.info(f"Memory write task submitted, end_user_id={end_user_id}")
return {
# "task_id": task.id,
"status": "PENDING",
"task_id": task_id,
"status": "QUEUED",
"end_user_id": end_user_id,
}