diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index e3d2bf92..aa4d48e3 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -118,142 +118,142 @@ async def download_log( return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e)) -@router.post("/writer_service", response_model=ApiResponse) -@cur_workspace_access_guard() -async def write_server( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Write service endpoint - processes write operations synchronously - - Args: - user_input: Write request containing message and end_user_id - language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 - - Returns: - Response with write operation status - """ - # 使用集中化的语言校验 - language = get_language_from_header(language_type) - - config_id = user_input.config_id - workspace_id = current_user.current_workspace_id - api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - - # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id - if storage_type == 'rag': - if workspace_id: - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: - user_rag_memory_id = str(knowledge.id) - else: - api_logger.warning( - f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") - storage_type = 'neo4j' - else: - api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") - storage_type = 'neo4j' - - api_logger.info( - f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") - try: - messages_list = memory_agent_service.get_messages_list(user_input) - result = await memory_agent_service.write_memory( - user_input.end_user_id, - messages_list, - config_id, - db, - storage_type, - user_rag_memory_id, - language - ) - - return success(data=result, msg="写入成功") - except BaseException as e: - # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup - if hasattr(e, 'exceptions'): - error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] - detailed_error = "; ".join(error_messages) - api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) - api_logger.error(f"Write operation error: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) - - -@router.post("/writer_service_async", response_model=ApiResponse) -@cur_workspace_access_guard() -async def write_server_async( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Async write service endpoint - enqueues write processing to Celery - - Args: - user_input: Write request containing message and end_user_id - language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 - - Returns: - Task ID for tracking async operation - Use GET /memory/write_result/{task_id} to check task status and get result - """ - # 使用集中化的语言校验 - language = get_language_from_header(language_type) - - config_id = user_input.config_id - workspace_id = current_user.current_workspace_id - api_logger.info( - f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - if workspace_id: - - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: user_rag_memory_id = str(knowledge.id) - api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - try: - # 获取标准化的消息列表 - messages_list = memory_agent_service.get_messages_list(user_input) - - task = celery_app.send_task( - "app.core.memory.agent.write_message", - args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] - ) - api_logger.info(f"Write task queued: {task.id}") - - return success(data={"task_id": task.id}, msg="写入任务已提交") - except Exception as e: - api_logger.error(f"Async write operation failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) +# @router.post("/writer_service", response_model=ApiResponse) +# @cur_workspace_access_guard() +# async def write_server( +# user_input: Write_UserInput, +# language_type: str = Header(default=None, alias="X-Language-Type"), +# db: Session = Depends(get_db), +# current_user: User = Depends(get_current_user) +# ): +# """ +# Write service endpoint - processes write operations synchronously +# +# Args: +# user_input: Write request containing message and end_user_id +# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 +# +# Returns: +# Response with write operation status +# """ +# # 使用集中化的语言校验 +# language = get_language_from_header(language_type) +# +# config_id = user_input.config_id +# workspace_id = current_user.current_workspace_id +# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") +# +# # 获取 storage_type,如果为 None 则使用默认值 +# storage_type = workspace_service.get_workspace_storage_type( +# db=db, +# workspace_id=workspace_id, +# user=current_user +# ) +# if storage_type is None: storage_type = 'neo4j' +# user_rag_memory_id = '' +# +# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id +# if storage_type == 'rag': +# if workspace_id: +# knowledge = knowledge_repository.get_knowledge_by_name( +# db=db, +# name="USER_RAG_MERORY", +# workspace_id=workspace_id +# ) +# if knowledge: +# user_rag_memory_id = str(knowledge.id) +# else: +# api_logger.warning( +# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") +# storage_type = 'neo4j' +# else: +# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") +# storage_type = 'neo4j' +# +# api_logger.info( +# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") +# try: +# messages_list = memory_agent_service.get_messages_list(user_input) +# result = await memory_agent_service.write_memory( +# user_input.end_user_id, +# messages_list, +# config_id, +# db, +# storage_type, +# user_rag_memory_id, +# language +# ) +# +# return success(data=result, msg="写入成功") +# except BaseException as e: +# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup +# if hasattr(e, 'exceptions'): +# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] +# detailed_error = "; ".join(error_messages) +# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) +# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) +# api_logger.error(f"Write operation error: {str(e)}", exc_info=True) +# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) +# +# +# @router.post("/writer_service_async", response_model=ApiResponse) +# @cur_workspace_access_guard() +# async def write_server_async( +# user_input: Write_UserInput, +# language_type: str = Header(default=None, alias="X-Language-Type"), +# db: Session = Depends(get_db), +# current_user: User = Depends(get_current_user) +# ): +# """ +# Async write service endpoint - enqueues write processing to Celery +# +# Args: +# user_input: Write request containing message and end_user_id +# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 +# +# Returns: +# Task ID for tracking async operation +# Use GET /memory/write_result/{task_id} to check task status and get result +# """ +# # 使用集中化的语言校验 +# language = get_language_from_header(language_type) +# +# config_id = user_input.config_id +# workspace_id = current_user.current_workspace_id +# api_logger.info( +# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") +# +# # 获取 storage_type,如果为 None 则使用默认值 +# storage_type = workspace_service.get_workspace_storage_type( +# db=db, +# workspace_id=workspace_id, +# user=current_user +# ) +# if storage_type is None: storage_type = 'neo4j' +# user_rag_memory_id = '' +# if workspace_id: +# +# knowledge = knowledge_repository.get_knowledge_by_name( +# db=db, +# name="USER_RAG_MERORY", +# workspace_id=workspace_id +# ) +# if knowledge: user_rag_memory_id = str(knowledge.id) +# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") +# try: +# # 获取标准化的消息列表 +# messages_list = memory_agent_service.get_messages_list(user_input) +# +# task = celery_app.send_task( +# "app.core.memory.agent.write_message", +# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] +# ) +# api_logger.info(f"Write task queued: {task.id}") +# +# return success(data={"task_id": task.id}, msg="写入任务已提交") +# except Exception as e: +# api_logger.error(f"Async write operation failed: {str(e)}") +# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) @router.post("/read_service", response_model=ApiResponse) diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 22f13449..5c2f81a7 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -309,57 +309,21 @@ class MemoryConfigRepository: Returns: Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None - - Raises: - ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"更新萃取配置: config_id={update.config_id}") try: - db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first() + stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id) + db_config = db.execute(stmt).scalar_one_or_none() if not db_config: db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") return None - # 更新字段映射 - field_mapping = { - # 模型选择 - "llm_id": "llm_id", - "embedding_id": "embedding_id", - "rerank_id": "rerank_id", - # 记忆萃取引擎 - "enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise", - "enable_llm_disambiguation": "enable_llm_disambiguation", - "deep_retrieval": "deep_retrieval", - "t_type_strict": "t_type_strict", - "t_name_strict": "t_name_strict", - "t_overall": "t_overall", - "state": "state", - "chunker_strategy": "chunker_strategy", - # 句子提取 - "statement_granularity": "statement_granularity", - "include_dialogue_context": "include_dialogue_context", - "max_context": "max_context", - # 剪枝配置 - "pruning_enabled": "pruning_enabled", - "pruning_scene": "pruning_scene", - "pruning_threshold": "pruning_threshold", - # 自我反思配置 - "enable_self_reflexion": "enable_self_reflexion", - "iteration_period": "iteration_period", - "reflexion_range": "reflexion_range", - "baseline": "baseline", - } + update_data = update.model_dump(exclude_unset=True) + update_data.pop("config_id", None) - has_update = False - for api_field, db_field in field_mapping.items(): - value = getattr(update, api_field, None) - if value is not None: - setattr(db_config, db_field, value) - has_update = True - - if not has_update: - raise ValueError("No fields to update") + for field, value in update_data.items(): + setattr(db_config, field, value) db.commit() db.refresh(db_config) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index af9a04e2..514cb12f 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -267,8 +267,16 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int, - db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: + async def write_memory( + self, + end_user_id: str, + messages: list[dict], + config_id: Optional[uuid.UUID] | int, + db: Session, + storage_type: str, + user_rag_memory_id: str, + language: str = "zh" + ) -> str: """ Process write operation with config_id @@ -297,8 +305,8 @@ class MemoryAgentService: config_id = connected_config.get("memory_config_id") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") if config_id is None and workspace_id is None: - raise ValueError( - f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") + raise ValueError(f"No memory configuration found for end_user {end_user_id}. " + f"Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error @@ -338,8 +346,8 @@ class MemoryAgentService: if storage_type == "rag": # For RAG storage, convert messages to single string message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - result = await write_rag(end_user_id, message_text, user_rag_memory_id) - return result + await write_rag(end_user_id, message_text, user_rag_memory_id) + return "success" else: async with make_write_graph() as graph: config = {"configurable": {"thread_id": end_user_id}} diff --git a/api/app/services/memory_konwledges_server.py b/api/app/services/memory_konwledges_server.py index b8961d33..523adadb 100644 --- a/api/app/services/memory_konwledges_server.py +++ b/api/app/services/memory_konwledges_server.py @@ -341,7 +341,7 @@ async def memory_konwledges_up( ) db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user) - return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") + return db_document async def create_document_chunk( @@ -350,7 +350,7 @@ async def create_document_chunk( create_data: ChunkCreate, db: Session, current_user: User -): +) -> DocumentChunk: """ 创建文档块 @@ -439,10 +439,10 @@ async def create_document_chunk( db_document.chunk_num += 1 db.commit() - return success(data=chunk, msg="文档块创建成功") + return chunk -async def write_rag(end_user_id, message, user_rag_memory_id): +async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk: """ 将消息写入 RAG 知识库 @@ -482,11 +482,11 @@ async def write_rag(end_user_id, message, user_rag_memory_id): document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") print('======', document) api_logger.info(f"查找文档结果: document_id={document}") + create_chunks = ChunkCreate(content=message) if document is not None: # 文档已存在,直接添加新块 api_logger.info(f"文档已存在,添加新块: document_id={document}") - create_chunks = ChunkCreate(content=message) result = await create_document_chunk( kb_id=kb_uuid, document_id=uuid.UUID(document), @@ -498,13 +498,20 @@ async def write_rag(end_user_id, message, user_rag_memory_id): else: # 文档不存在,创建新文档 api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}") - result = await memory_konwledges_up( + document = await memory_konwledges_up( kb_id=user_rag_memory_id, parent_id=user_rag_memory_id, create_data=create_data, db=db, current_user=current_user ) + result = await create_document_chunk( + kb_id=kb_uuid, + document_id=document.id, + create_data=create_chunks, + db=db, + current_user=current_user + ) # 重新查询刚创建的文档ID new_document_id = find_document_id_by_kb_and_filename( db=db,