fix(celery, rag): unify rag_write return type and remove deprecated downstream calls
- Unify the return type of `rag_write` in Celery tasks for consistency. - Remove two deprecated downstream API calls to avoid obsolete dependencies.
This commit is contained in:
@@ -118,142 +118,142 @@ async def download_log(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/writer_service", response_model=ApiResponse)
|
# @router.post("/writer_service", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
# @cur_workspace_access_guard()
|
||||||
async def write_server(
|
# async def write_server(
|
||||||
user_input: Write_UserInput,
|
# user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
# db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
# current_user: User = Depends(get_current_user)
|
||||||
):
|
# ):
|
||||||
"""
|
# """
|
||||||
Write service endpoint - processes write operations synchronously
|
# Write service endpoint - processes write operations synchronously
|
||||||
|
#
|
||||||
Args:
|
# Args:
|
||||||
user_input: Write request containing message and end_user_id
|
# user_input: Write request containing message and end_user_id
|
||||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||||
|
#
|
||||||
Returns:
|
# Returns:
|
||||||
Response with write operation status
|
# Response with write operation status
|
||||||
"""
|
# """
|
||||||
# 使用集中化的语言校验
|
# # 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
# language = get_language_from_header(language_type)
|
||||||
|
#
|
||||||
config_id = user_input.config_id
|
# config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
# workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
#
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# # 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
# storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
# db=db,
|
||||||
workspace_id=workspace_id,
|
# workspace_id=workspace_id,
|
||||||
user=current_user
|
# user=current_user
|
||||||
)
|
# )
|
||||||
if storage_type is None: storage_type = 'neo4j'
|
# if storage_type is None: storage_type = 'neo4j'
|
||||||
user_rag_memory_id = ''
|
# user_rag_memory_id = ''
|
||||||
|
#
|
||||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||||
if storage_type == 'rag':
|
# if storage_type == 'rag':
|
||||||
if workspace_id:
|
# if workspace_id:
|
||||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
db=db,
|
# db=db,
|
||||||
name="USER_RAG_MERORY",
|
# name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
# workspace_id=workspace_id
|
||||||
)
|
# )
|
||||||
if knowledge:
|
# if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
# user_rag_memory_id = str(knowledge.id)
|
||||||
else:
|
# else:
|
||||||
api_logger.warning(
|
# api_logger.warning(
|
||||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
# storage_type = 'neo4j'
|
||||||
else:
|
# else:
|
||||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
# storage_type = 'neo4j'
|
||||||
|
#
|
||||||
api_logger.info(
|
# api_logger.info(
|
||||||
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||||
try:
|
# try:
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
result = await memory_agent_service.write_memory(
|
# result = await memory_agent_service.write_memory(
|
||||||
user_input.end_user_id,
|
# user_input.end_user_id,
|
||||||
messages_list,
|
# messages_list,
|
||||||
config_id,
|
# config_id,
|
||||||
db,
|
# db,
|
||||||
storage_type,
|
# storage_type,
|
||||||
user_rag_memory_id,
|
# user_rag_memory_id,
|
||||||
language
|
# language
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
return success(data=result, msg="写入成功")
|
# return success(data=result, msg="写入成功")
|
||||||
except BaseException as e:
|
# except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
if hasattr(e, 'exceptions'):
|
# if hasattr(e, 'exceptions'):
|
||||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||||
detailed_error = "; ".join(error_messages)
|
# detailed_error = "; ".join(error_messages)
|
||||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||||
|
#
|
||||||
|
#
|
||||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
# @router.post("/writer_service_async", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
# @cur_workspace_access_guard()
|
||||||
async def write_server_async(
|
# async def write_server_async(
|
||||||
user_input: Write_UserInput,
|
# user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
# db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
# current_user: User = Depends(get_current_user)
|
||||||
):
|
# ):
|
||||||
"""
|
# """
|
||||||
Async write service endpoint - enqueues write processing to Celery
|
# Async write service endpoint - enqueues write processing to Celery
|
||||||
|
#
|
||||||
Args:
|
# Args:
|
||||||
user_input: Write request containing message and end_user_id
|
# user_input: Write request containing message and end_user_id
|
||||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||||
|
#
|
||||||
Returns:
|
# Returns:
|
||||||
Task ID for tracking async operation
|
# Task ID for tracking async operation
|
||||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
# Use GET /memory/write_result/{task_id} to check task status and get result
|
||||||
"""
|
# """
|
||||||
# 使用集中化的语言校验
|
# # 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
# language = get_language_from_header(language_type)
|
||||||
|
#
|
||||||
config_id = user_input.config_id
|
# config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
# workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(
|
# api_logger.info(
|
||||||
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
#
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# # 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
# storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
# db=db,
|
||||||
workspace_id=workspace_id,
|
# workspace_id=workspace_id,
|
||||||
user=current_user
|
# user=current_user
|
||||||
)
|
# )
|
||||||
if storage_type is None: storage_type = 'neo4j'
|
# if storage_type is None: storage_type = 'neo4j'
|
||||||
user_rag_memory_id = ''
|
# user_rag_memory_id = ''
|
||||||
if workspace_id:
|
# if workspace_id:
|
||||||
|
#
|
||||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
db=db,
|
# db=db,
|
||||||
name="USER_RAG_MERORY",
|
# name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
# workspace_id=workspace_id
|
||||||
)
|
# )
|
||||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
# if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
try:
|
# try:
|
||||||
# 获取标准化的消息列表
|
# # 获取标准化的消息列表
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
|
#
|
||||||
task = celery_app.send_task(
|
# task = celery_app.send_task(
|
||||||
"app.core.memory.agent.write_message",
|
# "app.core.memory.agent.write_message",
|
||||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||||
)
|
# )
|
||||||
api_logger.info(f"Write task queued: {task.id}")
|
# api_logger.info(f"Write task queued: {task.id}")
|
||||||
|
#
|
||||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
# return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
# api_logger.error(f"Async write operation failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/read_service", response_model=ApiResponse)
|
@router.post("/read_service", response_model=ApiResponse)
|
||||||
|
|||||||
@@ -309,57 +309,21 @@ class MemoryConfigRepository:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None
|
Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 没有字段需要更新时抛出
|
|
||||||
"""
|
"""
|
||||||
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
|
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
|
stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id)
|
||||||
|
db_config = db.execute(stmt).scalar_one_or_none()
|
||||||
if not db_config:
|
if not db_config:
|
||||||
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 更新字段映射
|
update_data = update.model_dump(exclude_unset=True)
|
||||||
field_mapping = {
|
update_data.pop("config_id", None)
|
||||||
# 模型选择
|
|
||||||
"llm_id": "llm_id",
|
|
||||||
"embedding_id": "embedding_id",
|
|
||||||
"rerank_id": "rerank_id",
|
|
||||||
# 记忆萃取引擎
|
|
||||||
"enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise",
|
|
||||||
"enable_llm_disambiguation": "enable_llm_disambiguation",
|
|
||||||
"deep_retrieval": "deep_retrieval",
|
|
||||||
"t_type_strict": "t_type_strict",
|
|
||||||
"t_name_strict": "t_name_strict",
|
|
||||||
"t_overall": "t_overall",
|
|
||||||
"state": "state",
|
|
||||||
"chunker_strategy": "chunker_strategy",
|
|
||||||
# 句子提取
|
|
||||||
"statement_granularity": "statement_granularity",
|
|
||||||
"include_dialogue_context": "include_dialogue_context",
|
|
||||||
"max_context": "max_context",
|
|
||||||
# 剪枝配置
|
|
||||||
"pruning_enabled": "pruning_enabled",
|
|
||||||
"pruning_scene": "pruning_scene",
|
|
||||||
"pruning_threshold": "pruning_threshold",
|
|
||||||
# 自我反思配置
|
|
||||||
"enable_self_reflexion": "enable_self_reflexion",
|
|
||||||
"iteration_period": "iteration_period",
|
|
||||||
"reflexion_range": "reflexion_range",
|
|
||||||
"baseline": "baseline",
|
|
||||||
}
|
|
||||||
|
|
||||||
has_update = False
|
for field, value in update_data.items():
|
||||||
for api_field, db_field in field_mapping.items():
|
setattr(db_config, field, value)
|
||||||
value = getattr(update, api_field, None)
|
|
||||||
if value is not None:
|
|
||||||
setattr(db_config, db_field, value)
|
|
||||||
has_update = True
|
|
||||||
|
|
||||||
if not has_update:
|
|
||||||
raise ValueError("No fields to update")
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_config)
|
db.refresh(db_config)
|
||||||
|
|||||||
@@ -267,8 +267,16 @@ class MemoryAgentService:
|
|||||||
logger.info("Log streaming completed, cleaning up resources")
|
logger.info("Log streaming completed, cleaning up resources")
|
||||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||||
|
|
||||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
|
async def write_memory(
|
||||||
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
messages: list[dict],
|
||||||
|
config_id: Optional[uuid.UUID] | int,
|
||||||
|
db: Session,
|
||||||
|
storage_type: str,
|
||||||
|
user_rag_memory_id: str,
|
||||||
|
language: str = "zh"
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Process write operation with config_id
|
Process write operation with config_id
|
||||||
|
|
||||||
@@ -297,8 +305,8 @@ class MemoryAgentService:
|
|||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||||
if config_id is None and workspace_id is None:
|
if config_id is None and workspace_id is None:
|
||||||
raise ValueError(
|
raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
|
||||||
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
f"Please ensure the user has a connected memory configuration.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "No memory configuration found" in str(e):
|
if "No memory configuration found" in str(e):
|
||||||
raise # Re-raise our specific error
|
raise # Re-raise our specific error
|
||||||
@@ -338,8 +346,8 @@ class MemoryAgentService:
|
|||||||
if storage_type == "rag":
|
if storage_type == "rag":
|
||||||
# For RAG storage, convert messages to single string
|
# For RAG storage, convert messages to single string
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||||
return result
|
return "success"
|
||||||
else:
|
else:
|
||||||
async with make_write_graph() as graph:
|
async with make_write_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
|
|||||||
@@ -341,7 +341,7 @@ async def memory_konwledges_up(
|
|||||||
)
|
)
|
||||||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||||||
|
|
||||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
return db_document
|
||||||
|
|
||||||
|
|
||||||
async def create_document_chunk(
|
async def create_document_chunk(
|
||||||
@@ -350,7 +350,7 @@ async def create_document_chunk(
|
|||||||
create_data: ChunkCreate,
|
create_data: ChunkCreate,
|
||||||
db: Session,
|
db: Session,
|
||||||
current_user: User
|
current_user: User
|
||||||
):
|
) -> DocumentChunk:
|
||||||
"""
|
"""
|
||||||
创建文档块
|
创建文档块
|
||||||
|
|
||||||
@@ -439,10 +439,10 @@ async def create_document_chunk(
|
|||||||
db_document.chunk_num += 1
|
db_document.chunk_num += 1
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return success(data=chunk, msg="文档块创建成功")
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
async def write_rag(end_user_id, message, user_rag_memory_id):
|
async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk:
|
||||||
"""
|
"""
|
||||||
将消息写入 RAG 知识库
|
将消息写入 RAG 知识库
|
||||||
|
|
||||||
@@ -482,11 +482,11 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
|||||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
||||||
print('======', document)
|
print('======', document)
|
||||||
api_logger.info(f"查找文档结果: document_id={document}")
|
api_logger.info(f"查找文档结果: document_id={document}")
|
||||||
|
create_chunks = ChunkCreate(content=message)
|
||||||
if document is not None:
|
if document is not None:
|
||||||
# 文档已存在,直接添加新块
|
# 文档已存在,直接添加新块
|
||||||
api_logger.info(f"文档已存在,添加新块: document_id={document}")
|
api_logger.info(f"文档已存在,添加新块: document_id={document}")
|
||||||
|
|
||||||
create_chunks = ChunkCreate(content=message)
|
|
||||||
result = await create_document_chunk(
|
result = await create_document_chunk(
|
||||||
kb_id=kb_uuid,
|
kb_id=kb_uuid,
|
||||||
document_id=uuid.UUID(document),
|
document_id=uuid.UUID(document),
|
||||||
@@ -498,13 +498,20 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
|||||||
else:
|
else:
|
||||||
# 文档不存在,创建新文档
|
# 文档不存在,创建新文档
|
||||||
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
|
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
|
||||||
result = await memory_konwledges_up(
|
document = await memory_konwledges_up(
|
||||||
kb_id=user_rag_memory_id,
|
kb_id=user_rag_memory_id,
|
||||||
parent_id=user_rag_memory_id,
|
parent_id=user_rag_memory_id,
|
||||||
create_data=create_data,
|
create_data=create_data,
|
||||||
db=db,
|
db=db,
|
||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
|
result = await create_document_chunk(
|
||||||
|
kb_id=kb_uuid,
|
||||||
|
document_id=document.id,
|
||||||
|
create_data=create_chunks,
|
||||||
|
db=db,
|
||||||
|
current_user=current_user
|
||||||
|
)
|
||||||
# 重新查询刚创建的文档ID
|
# 重新查询刚创建的文档ID
|
||||||
new_document_id = find_document_id_by_kb_and_filename(
|
new_document_id = find_document_id_by_kb_and_filename(
|
||||||
db=db,
|
db=db,
|
||||||
|
|||||||
Reference in New Issue
Block a user