fix(core): fix end_user_id reference and add task status tracking
- Fix write_router to use actual_end_user_id instead of end_user_id - Add task status tracking via Redis in scheduler - Expose task_id in memory write response - Fix logging import path in scheduler
This commit is contained in:
@@ -6,7 +6,7 @@ import redis
|
|||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from celery_app import celery_app
|
from celery_app import celery_app
|
||||||
from core.logging_config import get_named_logger
|
from app.core.logging_config import get_named_logger
|
||||||
|
|
||||||
logger = get_named_logger("task_scheduler")
|
logger = get_named_logger("task_scheduler")
|
||||||
|
|
||||||
@@ -62,11 +62,34 @@ class RedisTaskScheduler:
|
|||||||
"params": json.dumps(params),
|
"params": json.dumps(params),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
self.redis.set(
|
||||||
|
f"task_tracker:{msg_id}",
|
||||||
|
json.dumps({"status": "QUEUED", "task_id": None}),
|
||||||
|
ex=86400
|
||||||
|
)
|
||||||
return msg_id
|
return msg_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Push task exception %s", e, exc_info=True)
|
logger.error("Push task exception %s", e, exc_info=True)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def get_task_status(self, msg_id: str) -> dict:
|
||||||
|
raw = self.redis.get(f"task_tracker:{msg_id}")
|
||||||
|
if raw is None:
|
||||||
|
return {"status": "NOT_FOUND"}
|
||||||
|
|
||||||
|
tracker = json.loads(raw)
|
||||||
|
status = tracker["status"]
|
||||||
|
task_id = tracker.get("task_id")
|
||||||
|
result_content = tracker.get("result") or {}
|
||||||
|
if status == "DISPATCHED" and task_id:
|
||||||
|
result_raw = self.redis.get(f"celery-task-meta-{task_id}")
|
||||||
|
if result_raw:
|
||||||
|
result_data = json.loads(result_raw)
|
||||||
|
status = result_data.get("status", status)
|
||||||
|
result_content = result_data.get("result")
|
||||||
|
|
||||||
|
return {"status": status, "task_id": task_id, "result": result_content}
|
||||||
|
|
||||||
def _cleanup_finished(self):
|
def _cleanup_finished(self):
|
||||||
pending = self.redis.hgetall(PENDING_HASH)
|
pending = self.redis.hgetall(PENDING_HASH)
|
||||||
if not pending:
|
if not pending:
|
||||||
@@ -91,6 +114,7 @@ class RedisTaskScheduler:
|
|||||||
age = now - dispatched_at
|
age = now - dispatched_at
|
||||||
|
|
||||||
should_cleanup = False
|
should_cleanup = False
|
||||||
|
result_data = None
|
||||||
if raw_result is not None:
|
if raw_result is not None:
|
||||||
result_data = json.loads(raw_result)
|
result_data = json.loads(raw_result)
|
||||||
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
|
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
|
||||||
@@ -104,8 +128,20 @@ class RedisTaskScheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if should_cleanup:
|
if should_cleanup:
|
||||||
|
final_status = result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
|
||||||
cleanup_pipe.delete(lock_key)
|
cleanup_pipe.delete(lock_key)
|
||||||
cleanup_pipe.hdel(PENDING_HASH, task_id)
|
cleanup_pipe.hdel(PENDING_HASH, task_id)
|
||||||
|
tracker_msg_id = meta.get("msg_id")
|
||||||
|
if tracker_msg_id:
|
||||||
|
cleanup_pipe.set(
|
||||||
|
f"task_tracker:{tracker_msg_id}",
|
||||||
|
json.dumps({
|
||||||
|
"status": final_status,
|
||||||
|
"task_id": task_id,
|
||||||
|
"result": result_data.get("result") or {}
|
||||||
|
}),
|
||||||
|
ex=86400,
|
||||||
|
)
|
||||||
has_cleanup = True
|
has_cleanup = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
|
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
|
||||||
@@ -125,9 +161,15 @@ class RedisTaskScheduler:
|
|||||||
pipe.set(lock_key, task.id, ex=3600)
|
pipe.set(lock_key, task.id, ex=3600)
|
||||||
pipe.hset(PENDING_HASH, task.id, json.dumps({
|
pipe.hset(PENDING_HASH, task.id, json.dumps({
|
||||||
"lock_key": lock_key,
|
"lock_key": lock_key,
|
||||||
"dispatched_at": time.time()
|
"dispatched_at": time.time(),
|
||||||
|
"msg_id": msg_id
|
||||||
}))
|
}))
|
||||||
pipe.xdel(STREAM_KEY, msg_id)
|
pipe.xdel(STREAM_KEY, msg_id)
|
||||||
|
pipe.set(
|
||||||
|
f"task_tracker:{msg_id}",
|
||||||
|
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
|
||||||
|
ex=86400,
|
||||||
|
)
|
||||||
pipe.execute()
|
pipe.execute()
|
||||||
self.dispatched += 1
|
self.dispatched += 1
|
||||||
logger.info("Task dispatched: %s", task.id)
|
logger.info("Task dispatched: %s", task.id)
|
||||||
@@ -183,8 +225,8 @@ class RedisTaskScheduler:
|
|||||||
if locked or lock_key in deliver_keys:
|
if locked or lock_key in deliver_keys:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
key = self._dispatch(msg_id, msg_data)
|
dispatched_successfully = self._dispatch(msg_id, msg_data)
|
||||||
if key:
|
if dispatched_successfully:
|
||||||
deliver_keys.add(lock_key)
|
deliver_keys.add(lock_key)
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from app.schemas.memory_api_schema import (
|
|||||||
MemoryWriteSyncResponse,
|
MemoryWriteSyncResponse,
|
||||||
)
|
)
|
||||||
from app.services.memory_api_service import MemoryAPIService
|
from app.services.memory_api_service import MemoryAPIService
|
||||||
|
from celery_task_scheduler import scheduler
|
||||||
|
|
||||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -86,7 +87,7 @@ async def write_memory(
|
|||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write task submitted: end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}")
|
||||||
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")
|
||||||
|
|
||||||
|
|
||||||
@@ -105,8 +106,7 @@ async def get_write_task_status(
|
|||||||
"""
|
"""
|
||||||
logger.info(f"Write task status check - task_id: {task_id}")
|
logger.info(f"Write task status check - task_id: {task_id}")
|
||||||
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
result = scheduler.get_task_status(task_id)
|
||||||
result = get_task_memory_write_result(task_id)
|
|
||||||
|
|
||||||
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
return success(data=_sanitize_task_result(result), msg="Task status retrieved")
|
||||||
|
|
||||||
|
|||||||
@@ -94,9 +94,9 @@ async def write(
|
|||||||
# )
|
# )
|
||||||
scheduler.push_task(
|
scheduler.push_task(
|
||||||
"app.core.memory.agent.write_message",
|
"app.core.memory.agent.write_message",
|
||||||
end_user_id,
|
actual_end_user_id,
|
||||||
{
|
{
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": actual_end_user_id,
|
||||||
"message": structured_messages,
|
"message": structured_messages,
|
||||||
"config_id": str(actual_config_id),
|
"config_id": str(actual_config_id),
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
|
|||||||
@@ -112,12 +112,12 @@ class MemoryWriteResponse(BaseModel):
|
|||||||
"""Response schema for memory write operation.
|
"""Response schema for memory write operation.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
task_id: Celery task ID for status polling
|
task_id: task ID for status polling
|
||||||
status: Initial task status (PENDING)
|
status: Initial task status (QUEUED)
|
||||||
end_user_id: End user ID the write was submitted for
|
end_user_id: End user ID the write was submitted for
|
||||||
"""
|
"""
|
||||||
task_id: str = Field(..., description="Celery task ID for polling")
|
task_id: str = Field(..., description="task ID for polling")
|
||||||
status: str = Field(..., description="Task status: PENDING")
|
status: str = Field(..., description="Task status: QUEUED")
|
||||||
end_user_id: str = Field(..., description="End user ID")
|
end_user_id: str = Field(..., description="End user ID")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ class MemoryAPIService:
|
|||||||
# storage_type,
|
# storage_type,
|
||||||
# user_rag_memory_id or "",
|
# user_rag_memory_id or "",
|
||||||
# )
|
# )
|
||||||
scheduler.push_task(
|
task_id = scheduler.push_task(
|
||||||
"app.core.memory.agent.write_message",
|
"app.core.memory.agent.write_message",
|
||||||
end_user_id,
|
end_user_id,
|
||||||
{
|
{
|
||||||
@@ -190,8 +190,8 @@ class MemoryAPIService:
|
|||||||
logger.info(f"Memory write task submitted, end_user_id={end_user_id}")
|
logger.info(f"Memory write task submitted, end_user_id={end_user_id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
# "task_id": task.id,
|
"task_id": task_id,
|
||||||
"status": "PENDING",
|
"status": "QUEUED",
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user