fix(celery, rag): unify rag_write return type and remove deprecated downstream calls

- Unify the return type of `rag_write` in Celery tasks for consistency.
- Remove two deprecated downstream API calls to avoid obsolete dependencies.
This commit is contained in:
Eternity
2026-03-20 18:28:21 +08:00
parent c17a2dad2d
commit dce7206c44
4 changed files with 169 additions and 190 deletions

View File

@@ -118,142 +118,142 @@ async def download_log(
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
@router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server(
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Write service endpoint - processes write operations synchronously
Args:
user_input: Write request containing message and end_user_id
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
Returns:
Response with write operation status
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag':
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j'
else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory(
user_input.end_user_id,
messages_list,
config_id,
db,
storage_type,
user_rag_memory_id,
language
)
return success(data=result, msg="写入成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server_async(
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Async write service endpoint - enqueues write processing to Celery
Args:
user_input: Write request containing message and end_user_id
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
Returns:
Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
task = celery_app.send_task(
"app.core.memory.agent.write_message",
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
)
api_logger.info(f"Write task queued: {task.id}")
return success(data={"task_id": task.id}, msg="写入任务已提交")
except Exception as e:
api_logger.error(f"Async write operation failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
# @router.post("/writer_service", response_model=ApiResponse)
# @cur_workspace_access_guard()
# async def write_server(
# user_input: Write_UserInput,
# language_type: str = Header(default=None, alias="X-Language-Type"),
# db: Session = Depends(get_db),
# current_user: User = Depends(get_current_user)
# ):
# """
# Write service endpoint - processes write operations synchronously
#
# Args:
# user_input: Write request containing message and end_user_id
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
#
# Returns:
# Response with write operation status
# """
# # 使用集中化的语言校验
# language = get_language_from_header(language_type)
#
# config_id = user_input.config_id
# workspace_id = current_user.current_workspace_id
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
#
# # 获取 storage_type如果为 None 则使用默认值
# storage_type = workspace_service.get_workspace_storage_type(
# db=db,
# workspace_id=workspace_id,
# user=current_user
# )
# if storage_type is None: storage_type = 'neo4j'
# user_rag_memory_id = ''
#
# # 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
# if storage_type == 'rag':
# if workspace_id:
# knowledge = knowledge_repository.get_knowledge_by_name(
# db=db,
# name="USER_RAG_MERORY",
# workspace_id=workspace_id
# )
# if knowledge:
# user_rag_memory_id = str(knowledge.id)
# else:
# api_logger.warning(
# f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
# storage_type = 'neo4j'
# else:
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
# storage_type = 'neo4j'
#
# api_logger.info(
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
# try:
# messages_list = memory_agent_service.get_messages_list(user_input)
# result = await memory_agent_service.write_memory(
# user_input.end_user_id,
# messages_list,
# config_id,
# db,
# storage_type,
# user_rag_memory_id,
# language
# )
#
# return success(data=result, msg="写入成功")
# except BaseException as e:
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
# if hasattr(e, 'exceptions'):
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
# detailed_error = "; ".join(error_messages)
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
#
#
# @router.post("/writer_service_async", response_model=ApiResponse)
# @cur_workspace_access_guard()
# async def write_server_async(
# user_input: Write_UserInput,
# language_type: str = Header(default=None, alias="X-Language-Type"),
# db: Session = Depends(get_db),
# current_user: User = Depends(get_current_user)
# ):
# """
# Async write service endpoint - enqueues write processing to Celery
#
# Args:
# user_input: Write request containing message and end_user_id
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
#
# Returns:
# Task ID for tracking async operation
# Use GET /memory/write_result/{task_id} to check task status and get result
# """
# # 使用集中化的语言校验
# language = get_language_from_header(language_type)
#
# config_id = user_input.config_id
# workspace_id = current_user.current_workspace_id
# api_logger.info(
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
#
# # 获取 storage_type如果为 None 则使用默认值
# storage_type = workspace_service.get_workspace_storage_type(
# db=db,
# workspace_id=workspace_id,
# user=current_user
# )
# if storage_type is None: storage_type = 'neo4j'
# user_rag_memory_id = ''
# if workspace_id:
#
# knowledge = knowledge_repository.get_knowledge_by_name(
# db=db,
# name="USER_RAG_MERORY",
# workspace_id=workspace_id
# )
# if knowledge: user_rag_memory_id = str(knowledge.id)
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# try:
# # 获取标准化的消息列表
# messages_list = memory_agent_service.get_messages_list(user_input)
#
# task = celery_app.send_task(
# "app.core.memory.agent.write_message",
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
# )
# api_logger.info(f"Write task queued: {task.id}")
#
# return success(data={"task_id": task.id}, msg="写入任务已提交")
# except Exception as e:
# api_logger.error(f"Async write operation failed: {str(e)}")
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@router.post("/read_service", response_model=ApiResponse)