From 1aff4eda67f3e8789bc8784fba3a7bf1f32c86a0 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Mon, 2 Feb 2026 20:31:45 +0800 Subject: [PATCH 01/45] memory_BUG_fix --- api/app/core/memory/agent/langgraph_graph/write_graph.py | 3 +-- api/app/services/draft_run_service.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 5101fa29..9547c866 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -49,7 +49,7 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ db_session = next(get_db()) config_service = MemoryConfigService(db_session) memory_config = config_service.load_memory_config( - config_id="08ed205c-0f05-49c3-8e0c-a580d28f5fd4", # 改为整数 + config_id=memory_config, # 改为整数 service_name="MemoryAgentService" ) if long_term_type=='chunk': @@ -59,7 +59,6 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ """时间""" await memory_long_term_storage(end_user_id, memory_config,5) if long_term_type=='aggregate': - """方案三:聚合判断""" await aggregate_judgment(end_user_id, langchain_messages, memory_config) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 9a3e1d37..43073555 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -110,6 +110,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str result = task_service.get_task_memory_read_result(task.id) status = result.get("status") logger.info(f"读取任务状态:{status}") + if memory_content: + memory_content = memory_content['answer'] finally: db.close() @@ -123,7 +125,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str "content_length": len(str(memory_content)) } ) - return f"检索到以下历史记忆:\n\n{memory_content}" except Exception as e: logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__}) From 88c95db8d0e6518de634700a0b91137a2a9671da Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 12:19:00 +0800 Subject: [PATCH 02/45] [add]The main project adds multi-API Key load balancing. --- api/app/controllers/ontology_controller.py | 35 ++++++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 43d3b1d2..895b6d40 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -190,19 +190,48 @@ def _get_ontology_service( detail="指定的LLM模型没有配置API密钥" ) - api_key_config = model_config.api_keys[0] + # 获取可用的 API Key(只选择激活状态的) + active_api_keys = [ak for ak in model_config.api_keys if ak.is_active] + if not active_api_keys: + logger.error(f"Model {llm_id} has no active API key") + raise HTTPException( + status_code=400, + detail="指定的LLM模型没有可用的API密钥" + ) + + # 对于组合模型,根据负载均衡策略选择 API Key + if model_config.is_composite and len(active_api_keys) > 1: + from app.models.models_model import LoadBalanceStrategy + if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: + # 轮询策略:选择使用次数最少的 API Key + api_key_config = min(active_api_keys, key=lambda x: int(x.usage_count or "0")) + else: + # 默认策略:按优先级选择 + api_key_config = min(active_api_keys, key=lambda x: int(x.priority or "1")) + logger.info( + f"Composite model using load balance strategy: {model_config.load_balance_strategy}, " + f"selected API Key: {api_key_config.id}, provider: {api_key_config.provider}" + ) + else: + api_key_config = active_api_keys[0] logger.info( f"Using specified model - user: {current_user.id}, " - f"model_id: {llm_id}, model_name: {api_key_config.model_name}" + f"model_id: {llm_id}, model_name: {api_key_config.model_name}, " + f"is_composite: {model_config.is_composite}" ) # 创建模型配置对象 from app.core.models.base import RedBearModelConfig + # 对于组合模型,使用 API Key 的 provider;否则使用 model_config 的 provider + actual_provider = api_key_config.provider if model_config.is_composite else ( + model_config.provider if hasattr(model_config, 'provider') else "openai" + ) + llm_model_config = RedBearModelConfig( model_name=api_key_config.model_name, - provider=model_config.provider if hasattr(model_config, 'provider') else "openai", + provider=actual_provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, max_retries=3, From ffff138a6ff4714e01feec563a23b631d016eec2 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 12:28:05 +0800 Subject: [PATCH 03/45] [changes]Attribute security access, secure numerical conversion, unified use of local variables --- api/app/controllers/ontology_controller.py | 32 +++++++++++++++------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 895b6d40..1c22529b 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -191,7 +191,7 @@ def _get_ontology_service( ) # 获取可用的 API Key(只选择激活状态的) - active_api_keys = [ak for ak in model_config.api_keys if ak.is_active] + active_api_keys = [ak for ak in model_config.api_keys if getattr(ak, 'is_active', True)] if not active_api_keys: logger.error(f"Model {llm_id} has no active API key") raise HTTPException( @@ -199,17 +199,29 @@ def _get_ontology_service( detail="指定的LLM模型没有可用的API密钥" ) + # 安全的数值转换辅助函数 + def safe_int(value, default: int = 0) -> int: + """安全地将值转换为整数,异常时返回默认值""" + if value is None: + return default + try: + return int(value) + except (ValueError, TypeError): + return default + # 对于组合模型,根据负载均衡策略选择 API Key - if model_config.is_composite and len(active_api_keys) > 1: + is_composite = getattr(model_config, 'is_composite', False) + if is_composite and len(active_api_keys) > 1: from app.models.models_model import LoadBalanceStrategy - if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: + load_balance_strategy = getattr(model_config, 'load_balance_strategy', None) + if load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: # 轮询策略:选择使用次数最少的 API Key - api_key_config = min(active_api_keys, key=lambda x: int(x.usage_count or "0")) + api_key_config = min(active_api_keys, key=lambda x: safe_int(x.usage_count, 0)) else: - # 默认策略:按优先级选择 - api_key_config = min(active_api_keys, key=lambda x: int(x.priority or "1")) + # 默认策略:按优先级选择(优先级数值越小越优先) + api_key_config = min(active_api_keys, key=lambda x: safe_int(x.priority, 1)) logger.info( - f"Composite model using load balance strategy: {model_config.load_balance_strategy}, " + f"Composite model using load balance strategy: {load_balance_strategy}, " f"selected API Key: {api_key_config.id}, provider: {api_key_config.provider}" ) else: @@ -218,15 +230,15 @@ def _get_ontology_service( logger.info( f"Using specified model - user: {current_user.id}, " f"model_id: {llm_id}, model_name: {api_key_config.model_name}, " - f"is_composite: {model_config.is_composite}" + f"is_composite: {is_composite}" ) # 创建模型配置对象 from app.core.models.base import RedBearModelConfig # 对于组合模型,使用 API Key 的 provider;否则使用 model_config 的 provider - actual_provider = api_key_config.provider if model_config.is_composite else ( - model_config.provider if hasattr(model_config, 'provider') else "openai" + actual_provider = api_key_config.provider if is_composite else ( + getattr(model_config, 'provider', None) or "openai" ) llm_model_config = RedBearModelConfig( From 34f0c3b90c06227d9220e322df41a1e7d5b5feb2 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 13:44:07 +0800 Subject: [PATCH 04/45] [changes]Active status filtering logic, API Key selection strategy --- api/app/controllers/ontology_controller.py | 25 +++++++++++----------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 1c22529b..8ad47bc1 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -190,15 +190,6 @@ def _get_ontology_service( detail="指定的LLM模型没有配置API密钥" ) - # 获取可用的 API Key(只选择激活状态的) - active_api_keys = [ak for ak in model_config.api_keys if getattr(ak, 'is_active', True)] - if not active_api_keys: - logger.error(f"Model {llm_id} has no active API key") - raise HTTPException( - status_code=400, - detail="指定的LLM模型没有可用的API密钥" - ) - # 安全的数值转换辅助函数 def safe_int(value, default: int = 0) -> int: """安全地将值转换为整数,异常时返回默认值""" @@ -209,9 +200,19 @@ def _get_ontology_service( except (ValueError, TypeError): return default - # 对于组合模型,根据负载均衡策略选择 API Key + # 获取可用的 API Key(只选择激活状态的) + # 注意:is_active 为 None 时视为非激活状态,避免旧记录被错误地当作激活 + active_api_keys = [ak for ak in model_config.api_keys if getattr(ak, 'is_active', None) is True] + if not active_api_keys: + logger.error(f"Model {llm_id} has no active API key") + raise HTTPException( + status_code=400, + detail="指定的LLM模型没有可用的API密钥" + ) + + # 根据负载均衡策略选择 API Key(组合模型和非组合模型统一处理) is_composite = getattr(model_config, 'is_composite', False) - if is_composite and len(active_api_keys) > 1: + if len(active_api_keys) > 1: from app.models.models_model import LoadBalanceStrategy load_balance_strategy = getattr(model_config, 'load_balance_strategy', None) if load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: @@ -221,7 +222,7 @@ def _get_ontology_service( # 默认策略:按优先级选择(优先级数值越小越优先) api_key_config = min(active_api_keys, key=lambda x: safe_int(x.priority, 1)) logger.info( - f"Composite model using load balance strategy: {load_balance_strategy}, " + f"Model (is_composite={is_composite}) using load balance strategy: {load_balance_strategy}, " f"selected API Key: {api_key_config.id}, provider: {api_key_config.provider}" ) else: From c8c7e9b3048bd236ca47c8c28e50ccf56adae16b Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 13:45:10 +0800 Subject: [PATCH 05/45] memory_BUG --- api/app/core/agent/langchain_agent.py | 136 ++----------- .../langgraph_graph/routing/write_router.py | 178 ++++++++++++------ .../agent/langgraph_graph/tools/write_tool.py | 34 +--- .../agent/langgraph_graph/write_graph.py | 20 +- api/app/schemas/memory_agent_schema.py | 14 +- 5 files changed, 167 insertions(+), 215 deletions(-) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 7e0015ae..e4204e83 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,14 +11,17 @@ import os import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence +from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage +from app.core.memory.agent.utils.write_tools import write from app.db import get_db from app.core.logging_config import get_business_logger from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_agent_service import ( get_end_user_connected_config, ) @@ -148,106 +151,6 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - # TODO: 移到memory module - async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): - db = next(get_db()) - #TODO: 魔法数字 - scope=6 - - try: - repo = LongTermMemoryRepository(db) - await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) - - from app.core.memory.agent.utils.redis_tool import write_store - result = write_store.get_session_by_userid(end_user_id) - - # Handle case where no session exists in Redis (returns False) - if not result or result is False: - logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update") - return - - if type=="chunk" or type=="aggregate": - data = await format_parsing(result, "dict") - chunk_data = data[:scope] - if len(chunk_data)==scope: - repo.upsert(end_user_id, chunk_data) - logger.info(f'写入短长期:') - else: - # TODO: This branch handles type="time" strategy, currently unused. - # Will be activated when time-based long-term storage is implemented. - # TODO: 魔法数字 - extract 5 to a constant - long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - logger.debug(f"No recent sessions in Redis for user {end_user_id}") - return - long_messages = await messages_parse(long_time_data) - repo.upsert(end_user_id, long_messages) - logger.info(f'写入短长期:') - finally: - db.close() - - async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): - """ - 写入记忆(支持结构化消息) - - Args: - storage_type: 存储类型 (neo4j/rag) - end_user_id: 终端用户ID - user_message: 用户消息内容 - ai_message: AI 回复内容 - user_rag_memory_id: RAG 记忆ID - actual_end_user_id: 实际用户ID - actual_config_id: 配置ID - - 逻辑说明: - - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - - Neo4j 模式:使用结构化消息列表 - 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] - 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) - 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 - """ - - db = next(get_db()) - try: - actual_config_id=resolve_config_id(actual_config_id, db) - - if storage_type == "rag": - # RAG 模式:组合消息为字符串格式(保持原有逻辑) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - else: - # Neo4j 模式:使用结构化消息列表 - structured_messages = [] - - # 始终添加用户消息(如果不为空) - if user_message: - structured_messages.append({"role": "user", "content": user_message}) - - # 只有当 AI 回复不为空时才添加 assistant 消息 - if ai_message: - structured_messages.append({"role": "assistant", "content": ai_message}) - - # 如果没有消息,直接返回 - if not structured_messages: - logger.warning(f"No messages to write for user {actual_end_user_id}") - return - - logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") - write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: 用户ID - structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] - actual_config_id, # config_id: 配置ID - storage_type, # storage_type: "neo4j" - user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) - ) - logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') - finally: - db.close() async def chat( self, message: str, @@ -321,14 +224,14 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: - long_term_messages=await agent_chat_messages(message_chat,content) - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") + if storage_type == "rag": + await write_rag(end_user_id, message_chat, content, user_rag_memory_id) + else: + long_term_messages=await agent_chat_messages(message_chat,content) + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + await long_term_storage(long_term_type="chunk",langchain_messages=long_term_messages,memory_config=actual_config_id,end_user_id=end_user_id,scope=2) + '''长期''' + await term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") response = { "content": content, "model": self.model_name, @@ -459,14 +362,15 @@ class LangChainAgent: yield total_tokens break if memory_flag: - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - long_term_messages = await agent_chat_messages(message_chat, full_content) - await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") + if storage_type == AgentMemory_Long_Term.STORAGE_RAG: + await write_rag(end_user_id, message_chat, full_content, user_rag_memory_id) + else: + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + CHUNK=AgentMemory_Long_Term.STRATEGY_CHUNK + SCOPE=AgentMemory_Long_Term.DEFAULT_SCOPE + long_term_messages = await agent_chat_messages(message_chat, full_content) + await long_term_storage(long_term_type=CHUNK,langchain_messages=long_term_messages,memory_config=actual_config_id,end_user_id=end_user_id,scope=SCOPE) + await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK,scope=SCOPE) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index e9de02b6..ab65caa7 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -1,8 +1,9 @@ +import json import os from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse +from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ @@ -10,46 +11,111 @@ from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import count_store from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context +from app.db import get_db_context, get_db +from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.schemas.memory_agent_schema import AgentMemory_Long_Term +from app.services.memory_konwledges_server import write_rag +from app.services.task_service import get_task_memory_write_result +from app.tasks import write_message_task +from app.utils.config_utils import resolve_config_id + logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') +async def write_rag(end_user_id, user_message, ai_message, user_rag_memory_id): + # RAG 模式:组合消息为字符串格式(保持原有逻辑) + combined_message = f"user: {user_message}\nassistant: {ai_message}" + await write_rag(end_user_id, combined_message, user_rag_memory_id) + logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') +async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, + actual_config_id, long_term_messages=[]): + """ + 写入记忆(支持结构化消息) -async def write_messages(end_user_id,langchain_messages,memory_config): - ''' - 写入数据到neo4j: - Args: + Args: + storage_type: 存储类型 (neo4j/rag) end_user_id: 终端用户ID - memory_config: 内存配置对象 - langchain_messages:原始数据LIST - ''' + user_message: 用户消息内容 + ai_message: AI 回复内容 + user_rag_memory_id: RAG 记忆ID + actual_end_user_id: 实际用户ID + actual_config_id: 配置ID + + 逻辑说明: + - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 + - Neo4j 模式:使用结构化消息列表 + 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] + 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) + 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + """ + + db = next(get_db()) try: + actual_config_id = resolve_config_id(actual_config_id, db) + # Neo4j 模式:使用结构化消息列表 + structured_messages = [] - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config - } + # 始终添加用户消息(如果不为空) + if isinstance(user_message, str) and user_message.strip() != "": + structured_messages.append({"role": "user", "content": user_message}) + + # 只有当 AI 回复不为空时才添加 assistant 消息 + if isinstance(ai_message, str) and ai_message.strip() != "": + structured_messages.append({"role": "assistant", "content": ai_message}) + + # 如果提供了 long_term_messages,使用它替代 structured_messages + if long_term_messages and isinstance(long_term_messages, list): + structured_messages = long_term_messages + elif long_term_messages and isinstance(long_term_messages, str): + # 如果是 JSON 字符串,先解析 + try: + structured_messages = json.loads(long_term_messages) + except json.JSONDecodeError: + logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}") + + # 如果没有消息,直接返回 + if not structured_messages: + logger.warning(f"No messages to write for user {actual_end_user_id}") + return + + logger.info( + f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") + write_id = write_message_task.delay( + actual_end_user_id, # end_user_id: 用户ID + structured_messages, # message: JSON 字符串格式的消息列表 + str(actual_config_id), # config_id: 配置ID字符串 + storage_type, # storage_type: "neo4j" + user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + ) + logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") + write_status = get_task_memory_write_result(str(write_id)) + logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') + finally: + db.close() + +async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): + db = next(get_db()) + try: + repo = LongTermMemoryRepository(db) + await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + + from app.core.memory.agent.utils.redis_tool import write_store + result = write_store.get_session_by_userid(end_user_id) + if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'---------写入短长期-----------') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') + finally: + db.close() - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - # TODO:删除 - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - print(contents) - except Exception as e: - import traceback - traceback.print_exc() '''根据窗口''' async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): ''' @@ -61,25 +127,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): scope:窗口大小 ''' scope=scope - redis_messages = [] is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: is_end_user_id = count_store.get_sessions_count(end_user_id)[0] redis_messages = count_store.get_sessions_count(end_user_id)[1] if is_end_user_id and int(is_end_user_id) != int(scope): - print(is_end_user_id) is_end_user_id += 1 langchain_messages += redis_messages count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) elif int(is_end_user_id) == int(scope): - print('写入长期记忆,并且设置为0') - print(is_end_user_id) - formatted_messages = await chat_data_format(redis_messages) - print(100*'-') - print(formatted_messages) - print(100*'-') - await write_messages(end_user_id, formatted_messages, memory_config) - count_store.update_sessions_count(end_user_id, 0, '') + logger.info('写入长期记忆NEO4J') + formatted_messages = (redis_messages) + # 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用) + if hasattr(memory_config, 'config_id'): + config_id = memory_config.config_id + else: + config_id = memory_config + + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + config_id, formatted_messages) + count_store.update_sessions_count(end_user_id, 1, langchain_messages) else: count_store.save_sessions_count(end_user_id, 1, langchain_messages) @@ -93,12 +160,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time): memory_config: 内存配置对象 ''' long_time_data = write_store.find_user_recent_sessions(end_user_id, time) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - return - format_messages = await chat_data_format(long_time_data) + format_messages = (long_time_data) + messages=[] + memory_config=memory_config.config_id + for i in format_messages: + message=json.loads(i['Query']) + messages+= message if format_messages!=[]: - await write_messages(end_user_id, format_messages, memory_config) + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + memory_config, messages) '''聚合判断''' async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: """ @@ -109,13 +179,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] memory_config: 内存配置对象 """ - + try: # 1. 获取历史会话数据(使用新方法) result = write_store.get_all_sessions_by_end_user_id(end_user_id) - - # Handle case where no session exists in Redis (returns False or empty) - if not result or result is False: + history = await format_parsing(result) + if not result: history = [] else: history = await format_parsing(result) @@ -154,7 +223,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config } if not structured.is_same_event: logger.info(result_dict) - await write_messages(end_user_id, output_value, memory_config) + await write("neo4j", end_user_id, "", "", None, end_user_id, + memory_config.config_id, output_value) return result_dict except Exception as e: diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py index a1fb8226..d0be8e5c 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -26,13 +26,13 @@ async def format_parsing(messages: list,type:str='string'): role = content['role'] content = content['content'] if type == "string": - if role == 'human': + if role == 'human' or role=="user": content = '用户:' + content else: content = 'AI:' + content result.append(content) - if type == "dict": - if role == 'human': + if type == "dict" : + if role == 'human' or role=="user": user.append( content) else: ai.append(content) @@ -57,33 +57,7 @@ async def messages_parse(messages: list | dict): for key, values in zip(user, ai): database.append({key, values}) return database -async def chat_data_format(messages: list | dict): - """ - 将消息格式化为 LangChain 消息格式 - - Args: - messages: 消息列表或字典 - - Returns: - LangChain 消息列表 - """ - langchain_messages = [] - if isinstance(messages, list): - for msg in messages: - if 'role' in msg.keys(): - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - langchain_messages.append(AIMessage(content=msg['content'])) - if "Query" in msg.keys(): - langchain_messages.append(HumanMessage(content=msg['Query'])) - langchain_messages.append(AIMessage(content=msg['Answer'])) - if isinstance(messages, dict): - if messages['type'] == 'human': - langchain_messages.append(HumanMessage(content=messages['content'])) - elif messages['type'] == 'ai': - langchain_messages.append(AIMessage(content=messages['content'])) - return langchain_messages + async def agent_chat_messages(user_content,ai_content): messages = [ diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 9547c866..64a1296c 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -7,7 +7,7 @@ from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph -from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState @@ -42,9 +42,8 @@ async def make_write_graph(): async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment - from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format from app.core.memory.agent.utils.redis_tool import write_store - write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) + write_store.save_session_write(end_user_id, (langchain_messages)) # 获取数据库会话 db_session = next(get_db()) config_service = MemoryConfigService(db_session) @@ -62,31 +61,24 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ """方案三:聚合判断""" await aggregate_judgment(end_user_id, langchain_messages, memory_config) -# + # async def main(): # """主函数 - 运行工作流""" # langchain_messages = [ # { # "role": "user", -# "content": "今天周五好开心啊" +# "content": "今天周五去爬山" # }, # { # "role": "assistant", -# "content": "你也这么觉得,我也是耶" +# "content": "好耶" # } # # ] # end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID # memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" -# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) -# from app.core.memory.agent.utils.redis_tool import write_store -# result=write_store.get_session_by_userid(end_user_id) -# data=await format_parsing(result,"dict") -# chunk_data=data[:6] +# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) # -# long_time_data = write_store.find_user_recent_sessions(end_user_id, 240) -# long_=await messages_parse(long_time_data) -# print(long_) # # # if __name__ == "__main__": diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index b6f50dd7..1a5017eb 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import Optional from pydantic import BaseModel @@ -14,4 +15,15 @@ class UserInput(BaseModel): class Write_UserInput(BaseModel): messages: list[dict] end_user_id: str - config_id: Optional[str] = None \ No newline at end of file + config_id: Optional[str] = None + +class AgentMemory_Long_Term(ABC): + """长期记忆配置常量""" + STORAGE_NEO4J = "neo4j" + STORAGE_RAG = "rag" + STRATEGY_AGGREGATE = "aggregate" + STRATEGY_CHUNK = "chunk" + STRATEGY_TIME = "time" + DEFAULT_SCOPE = 6 + + From 2d28b4b05cb42aecdf895cc3eebac820967c70d7 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 13:54:32 +0800 Subject: [PATCH 06/45] memory_BUG_long_term --- api/app/core/agent/langchain_agent.py | 36 +++---------------- .../agent/langgraph_graph/write_graph.py | 18 +++++++++- 2 files changed, 21 insertions(+), 33 deletions(-) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index e4204e83..e519ea53 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -7,33 +7,21 @@ LangChain Agent 封装 - 支持流式输出 - 使用 RedBearLLM 支持多提供商 """ -import os + import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save -from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse -from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage -from app.core.memory.agent.utils.write_tools import write +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.core.logging_config import get_business_logger -from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType -from app.repositories.memory_short_repository import LongTermMemoryRepository -from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_agent_service import ( get_end_user_connected_config, ) -from app.services.memory_konwledges_server import write_rag -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool - -from app.utils.config_utils import resolve_config_id - logger = get_business_logger() @@ -224,14 +212,7 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: - if storage_type == "rag": - await write_rag(end_user_id, message_chat, content, user_rag_memory_id) - else: - long_term_messages=await agent_chat_messages(message_chat,content) - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) - await long_term_storage(long_term_type="chunk",langchain_messages=long_term_messages,memory_config=actual_config_id,end_user_id=end_user_id,scope=2) - '''长期''' - await term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") + await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id) response = { "content": content, "model": self.model_name, @@ -362,16 +343,7 @@ class LangChainAgent: yield total_tokens break if memory_flag: - if storage_type == AgentMemory_Long_Term.STORAGE_RAG: - await write_rag(end_user_id, message_chat, full_content, user_rag_memory_id) - else: - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) - CHUNK=AgentMemory_Long_Term.STRATEGY_CHUNK - SCOPE=AgentMemory_Long_Term.DEFAULT_SCOPE - long_term_messages = await agent_chat_messages(message_chat, full_content) - await long_term_storage(long_term_type=CHUNK,langchain_messages=long_term_messages,memory_config=actual_config_id,end_user_id=end_user_id,scope=SCOPE) - await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK,scope=SCOPE) - + await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 64a1296c..b788d1ec 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -7,13 +7,14 @@ from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph -from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_config_service import MemoryConfigService + warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) @@ -62,6 +63,21 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ await aggregate_judgment(end_user_id, langchain_messages, memory_config) +async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): + from app.services.memory_konwledges_server import write_rag + from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save + from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages + if storage_type == AgentMemory_Long_Term.STORAGE_RAG: + await write_rag(end_user_id, message_chat, aimessages, user_rag_memory_id) + else: + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK + SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE + long_term_messages = await agent_chat_messages(message_chat, aimessages) + await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) + await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) + # async def main(): # """主函数 - 运行工作流""" # langchain_messages = [ From 333836f5e796cfdca9977fe19c8790bbf5ebc768 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 14:08:09 +0800 Subject: [PATCH 07/45] [changes] --- api/app/controllers/ontology_controller.py | 46 +++------------------- 1 file changed, 6 insertions(+), 40 deletions(-) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 8ad47bc1..6520d835 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -182,56 +182,22 @@ def _get_ontology_service( detail=f"找不到指定的LLM模型: {llm_id}" ) - # 验证模型配置了API密钥 - if not model_config.api_keys: - logger.error(f"Model {llm_id} has no API key configuration") - raise HTTPException( - status_code=400, - detail="指定的LLM模型没有配置API密钥" - ) - - # 安全的数值转换辅助函数 - def safe_int(value, default: int = 0) -> int: - """安全地将值转换为整数,异常时返回默认值""" - if value is None: - return default - try: - return int(value) - except (ValueError, TypeError): - return default - - # 获取可用的 API Key(只选择激活状态的) - # 注意:is_active 为 None 时视为非激活状态,避免旧记录被错误地当作激活 - active_api_keys = [ak for ak in model_config.api_keys if getattr(ak, 'is_active', None) is True] - if not active_api_keys: + # 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理) + from app.repositories.model_repository import ModelApiKeyRepository + api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) + if not api_keys: logger.error(f"Model {llm_id} has no active API key") raise HTTPException( status_code=400, detail="指定的LLM模型没有可用的API密钥" ) + api_key_config = api_keys[0] - # 根据负载均衡策略选择 API Key(组合模型和非组合模型统一处理) is_composite = getattr(model_config, 'is_composite', False) - if len(active_api_keys) > 1: - from app.models.models_model import LoadBalanceStrategy - load_balance_strategy = getattr(model_config, 'load_balance_strategy', None) - if load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: - # 轮询策略:选择使用次数最少的 API Key - api_key_config = min(active_api_keys, key=lambda x: safe_int(x.usage_count, 0)) - else: - # 默认策略:按优先级选择(优先级数值越小越优先) - api_key_config = min(active_api_keys, key=lambda x: safe_int(x.priority, 1)) - logger.info( - f"Model (is_composite={is_composite}) using load balance strategy: {load_balance_strategy}, " - f"selected API Key: {api_key_config.id}, provider: {api_key_config.provider}" - ) - else: - api_key_config = active_api_keys[0] - logger.info( f"Using specified model - user: {current_user.id}, " f"model_id: {llm_id}, model_name: {api_key_config.model_name}, " - f"is_composite: {is_composite}" + f"is_composite: {is_composite}, api_key_id: {api_key_config.id}" ) # 创建模型配置对象 From 62aba2dd38907bdba5bb8e5b34036455c2e16168 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 14:21:49 +0800 Subject: [PATCH 08/45] memory_BUG_long_term --- .../langgraph_graph/routing/write_router.py | 46 ++++++++++--------- .../agent/langgraph_graph/write_graph.py | 4 +- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index ab65caa7..29257e88 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -22,7 +22,7 @@ from app.utils.config_utils import resolve_config_id logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -async def write_rag(end_user_id, user_message, ai_message, user_rag_memory_id): +async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): # RAG 模式:组合消息为字符串格式(保持原有逻辑) combined_message = f"user: {user_message}\nassistant: {ai_message}" await write_rag(end_user_id, combined_message, user_rag_memory_id) @@ -94,27 +94,31 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me db.close() async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): - db = next(get_db()) - try: - repo = LongTermMemoryRepository(db) - await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + with get_db_context() as db_session: + try: + repo = LongTermMemoryRepository(db_session) + await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + + from app.core.memory.agent.utils.redis_tool import write_store + result = write_store.get_session_by_userid(end_user_id) + if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'---------写入短长期-----------') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') + # yield db_session + finally: + if db_session.in_transaction(): + db_session.rollback() + db_session.close() - from app.core.memory.agent.utils.redis_tool import write_store - result = write_store.get_session_by_userid(end_user_id) - if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: - data = await format_parsing(result, "dict") - chunk_data = data[:scope] - if len(chunk_data)==scope: - repo.upsert(end_user_id, chunk_data) - logger.info(f'---------写入短长期-----------') - else: - long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) - long_messages = await messages_parse(long_time_data) - repo.upsert(end_user_id, long_messages) - logger.info(f'写入短长期:') - finally: - db.close() '''根据窗口''' async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index b788d1ec..97f894f7 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -64,11 +64,11 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): - from app.services.memory_konwledges_server import write_rag + from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages if storage_type == AgentMemory_Long_Term.STORAGE_RAG: - await write_rag(end_user_id, message_chat, aimessages, user_rag_memory_id) + await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) else: # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK From 72b5e5cf8e529c89668d679fc1c78c5f2f8bcd22 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 14:24:50 +0800 Subject: [PATCH 09/45] memory_BUG_long_term --- .../agent/langgraph_graph/write_graph.py | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 97f894f7..c0e6f86e 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -7,7 +7,7 @@ from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph -from app.db import get_db +from app.db import get_db, get_db_context from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node @@ -46,21 +46,27 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ from app.core.memory.agent.utils.redis_tool import write_store write_store.save_session_write(end_user_id, (langchain_messages)) # 获取数据库会话 - db_session = next(get_db()) - config_service = MemoryConfigService(db_session) - memory_config = config_service.load_memory_config( - config_id=memory_config, # 改为整数 - service_name="MemoryAgentService" - ) - if long_term_type=='chunk': - '''方案一:对话窗口6轮对话''' - await window_dialogue(end_user_id,langchain_messages,memory_config,scope) - if long_term_type=='time': - """时间""" - await memory_long_term_storage(end_user_id, memory_config,5) - if long_term_type=='aggregate': - """方案三:聚合判断""" - await aggregate_judgment(end_user_id, langchain_messages, memory_config) + with get_db_context() as db_session: + try: + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=memory_config, # 改为整数 + service_name="MemoryAgentService" + ) + if long_term_type=='chunk': + '''方案一:对话窗口6轮对话''' + await window_dialogue(end_user_id,langchain_messages,memory_config,scope) + if long_term_type=='time': + """时间""" + await memory_long_term_storage(end_user_id, memory_config,5) + if long_term_type=='aggregate': + """方案三:聚合判断""" + await aggregate_judgment(end_user_id, langchain_messages, memory_config) + finally: + if db_session.in_transaction(): + db_session.rollback() + db_session.close() + async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): From 8f0a1d9c6e18349e1d819c4d60b5861492f622de Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Wed, 4 Feb 2026 14:34:00 +0800 Subject: [PATCH 10/45] Fix/release memory bug (#306) * memory_BUG_fix * memory_BUG * memory_BUG_long_term * memory_BUG_long_term * memory_BUG_long_term --- api/app/core/agent/langchain_agent.py | 132 +------------ .../langgraph_graph/routing/write_router.py | 182 ++++++++++++------ .../agent/langgraph_graph/tools/write_tool.py | 34 +--- .../agent/langgraph_graph/write_graph.py | 105 +++++----- api/app/schemas/memory_agent_schema.py | 14 +- api/app/services/draft_run_service.py | 3 +- 6 files changed, 202 insertions(+), 268 deletions(-) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 7e0015ae..e519ea53 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -7,30 +7,21 @@ LangChain Agent 封装 - 支持流式输出 - 使用 RedBearLLM 支持多提供商 """ -import os + import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse -from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.core.logging_config import get_business_logger -from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType -from app.repositories.memory_short_repository import LongTermMemoryRepository from app.services.memory_agent_service import ( get_end_user_connected_config, ) -from app.services.memory_konwledges_server import write_rag -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool - -from app.utils.config_utils import resolve_config_id - logger = get_business_logger() @@ -148,106 +139,6 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - # TODO: 移到memory module - async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): - db = next(get_db()) - #TODO: 魔法数字 - scope=6 - - try: - repo = LongTermMemoryRepository(db) - await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) - - from app.core.memory.agent.utils.redis_tool import write_store - result = write_store.get_session_by_userid(end_user_id) - - # Handle case where no session exists in Redis (returns False) - if not result or result is False: - logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update") - return - - if type=="chunk" or type=="aggregate": - data = await format_parsing(result, "dict") - chunk_data = data[:scope] - if len(chunk_data)==scope: - repo.upsert(end_user_id, chunk_data) - logger.info(f'写入短长期:') - else: - # TODO: This branch handles type="time" strategy, currently unused. - # Will be activated when time-based long-term storage is implemented. - # TODO: 魔法数字 - extract 5 to a constant - long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - logger.debug(f"No recent sessions in Redis for user {end_user_id}") - return - long_messages = await messages_parse(long_time_data) - repo.upsert(end_user_id, long_messages) - logger.info(f'写入短长期:') - finally: - db.close() - - async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): - """ - 写入记忆(支持结构化消息) - - Args: - storage_type: 存储类型 (neo4j/rag) - end_user_id: 终端用户ID - user_message: 用户消息内容 - ai_message: AI 回复内容 - user_rag_memory_id: RAG 记忆ID - actual_end_user_id: 实际用户ID - actual_config_id: 配置ID - - 逻辑说明: - - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - - Neo4j 模式:使用结构化消息列表 - 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] - 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) - 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 - """ - - db = next(get_db()) - try: - actual_config_id=resolve_config_id(actual_config_id, db) - - if storage_type == "rag": - # RAG 模式:组合消息为字符串格式(保持原有逻辑) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - else: - # Neo4j 模式:使用结构化消息列表 - structured_messages = [] - - # 始终添加用户消息(如果不为空) - if user_message: - structured_messages.append({"role": "user", "content": user_message}) - - # 只有当 AI 回复不为空时才添加 assistant 消息 - if ai_message: - structured_messages.append({"role": "assistant", "content": ai_message}) - - # 如果没有消息,直接返回 - if not structured_messages: - logger.warning(f"No messages to write for user {actual_end_user_id}") - return - - logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") - write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: 用户ID - structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] - actual_config_id, # config_id: 配置ID - storage_type, # storage_type: "neo4j" - user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) - ) - logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') - finally: - db.close() async def chat( self, message: str, @@ -321,14 +212,7 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: - long_term_messages=await agent_chat_messages(message_chat,content) - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") + await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id) response = { "content": content, "model": self.model_name, @@ -459,15 +343,7 @@ class LangChainAgent: yield total_tokens break if memory_flag: - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - long_term_messages = await agent_chat_messages(message_chat, full_content) - await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") - + await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index e9de02b6..29257e88 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -1,8 +1,9 @@ +import json import os from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse +from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ @@ -10,46 +11,115 @@ from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import count_store from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context +from app.db import get_db_context, get_db +from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.schemas.memory_agent_schema import AgentMemory_Long_Term +from app.services.memory_konwledges_server import write_rag +from app.services.task_service import get_task_memory_write_result +from app.tasks import write_message_task +from app.utils.config_utils import resolve_config_id + logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') +async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): + # RAG 模式:组合消息为字符串格式(保持原有逻辑) + combined_message = f"user: {user_message}\nassistant: {ai_message}" + await write_rag(end_user_id, combined_message, user_rag_memory_id) + logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') +async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, + actual_config_id, long_term_messages=[]): + """ + 写入记忆(支持结构化消息) -async def write_messages(end_user_id,langchain_messages,memory_config): - ''' - 写入数据到neo4j: - Args: + Args: + storage_type: 存储类型 (neo4j/rag) end_user_id: 终端用户ID - memory_config: 内存配置对象 - langchain_messages:原始数据LIST - ''' + user_message: 用户消息内容 + ai_message: AI 回复内容 + user_rag_memory_id: RAG 记忆ID + actual_end_user_id: 实际用户ID + actual_config_id: 配置ID + + 逻辑说明: + - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 + - Neo4j 模式:使用结构化消息列表 + 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] + 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) + 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + """ + + db = next(get_db()) try: + actual_config_id = resolve_config_id(actual_config_id, db) + # Neo4j 模式:使用结构化消息列表 + structured_messages = [] + + # 始终添加用户消息(如果不为空) + if isinstance(user_message, str) and user_message.strip() != "": + structured_messages.append({"role": "user", "content": user_message}) + + # 只有当 AI 回复不为空时才添加 assistant 消息 + if isinstance(ai_message, str) and ai_message.strip() != "": + structured_messages.append({"role": "assistant", "content": ai_message}) + + # 如果提供了 long_term_messages,使用它替代 structured_messages + if long_term_messages and isinstance(long_term_messages, list): + structured_messages = long_term_messages + elif long_term_messages and isinstance(long_term_messages, str): + # 如果是 JSON 字符串,先解析 + try: + structured_messages = json.loads(long_term_messages) + except json.JSONDecodeError: + logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}") + + # 如果没有消息,直接返回 + if not structured_messages: + logger.warning(f"No messages to write for user {actual_end_user_id}") + return + + logger.info( + f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") + write_id = write_message_task.delay( + actual_end_user_id, # end_user_id: 用户ID + structured_messages, # message: JSON 字符串格式的消息列表 + str(actual_config_id), # config_id: 配置ID字符串 + storage_type, # storage_type: "neo4j" + user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + ) + logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") + write_status = get_task_memory_write_result(str(write_id)) + logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') + finally: + db.close() + +async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): + with get_db_context() as db_session: + try: + repo = LongTermMemoryRepository(db_session) + await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + + from app.core.memory.agent.utils.redis_tool import write_store + result = write_store.get_session_by_userid(end_user_id) + if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'---------写入短长期-----------') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') + # yield db_session + finally: + if db_session.in_transaction(): + db_session.rollback() + db_session.close() - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config - } - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - # TODO:删除 - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - print(contents) - except Exception as e: - import traceback - traceback.print_exc() '''根据窗口''' async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): ''' @@ -61,25 +131,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): scope:窗口大小 ''' scope=scope - redis_messages = [] is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: is_end_user_id = count_store.get_sessions_count(end_user_id)[0] redis_messages = count_store.get_sessions_count(end_user_id)[1] if is_end_user_id and int(is_end_user_id) != int(scope): - print(is_end_user_id) is_end_user_id += 1 langchain_messages += redis_messages count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) elif int(is_end_user_id) == int(scope): - print('写入长期记忆,并且设置为0') - print(is_end_user_id) - formatted_messages = await chat_data_format(redis_messages) - print(100*'-') - print(formatted_messages) - print(100*'-') - await write_messages(end_user_id, formatted_messages, memory_config) - count_store.update_sessions_count(end_user_id, 0, '') + logger.info('写入长期记忆NEO4J') + formatted_messages = (redis_messages) + # 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用) + if hasattr(memory_config, 'config_id'): + config_id = memory_config.config_id + else: + config_id = memory_config + + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + config_id, formatted_messages) + count_store.update_sessions_count(end_user_id, 1, langchain_messages) else: count_store.save_sessions_count(end_user_id, 1, langchain_messages) @@ -93,12 +164,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time): memory_config: 内存配置对象 ''' long_time_data = write_store.find_user_recent_sessions(end_user_id, time) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - return - format_messages = await chat_data_format(long_time_data) + format_messages = (long_time_data) + messages=[] + memory_config=memory_config.config_id + for i in format_messages: + message=json.loads(i['Query']) + messages+= message if format_messages!=[]: - await write_messages(end_user_id, format_messages, memory_config) + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + memory_config, messages) '''聚合判断''' async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: """ @@ -109,13 +183,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] memory_config: 内存配置对象 """ - + try: # 1. 获取历史会话数据(使用新方法) result = write_store.get_all_sessions_by_end_user_id(end_user_id) - - # Handle case where no session exists in Redis (returns False or empty) - if not result or result is False: + history = await format_parsing(result) + if not result: history = [] else: history = await format_parsing(result) @@ -154,7 +227,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config } if not structured.is_same_event: logger.info(result_dict) - await write_messages(end_user_id, output_value, memory_config) + await write("neo4j", end_user_id, "", "", None, end_user_id, + memory_config.config_id, output_value) return result_dict except Exception as e: diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py index a1fb8226..d0be8e5c 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -26,13 +26,13 @@ async def format_parsing(messages: list,type:str='string'): role = content['role'] content = content['content'] if type == "string": - if role == 'human': + if role == 'human' or role=="user": content = '用户:' + content else: content = 'AI:' + content result.append(content) - if type == "dict": - if role == 'human': + if type == "dict" : + if role == 'human' or role=="user": user.append( content) else: ai.append(content) @@ -57,33 +57,7 @@ async def messages_parse(messages: list | dict): for key, values in zip(user, ai): database.append({key, values}) return database -async def chat_data_format(messages: list | dict): - """ - 将消息格式化为 LangChain 消息格式 - - Args: - messages: 消息列表或字典 - - Returns: - LangChain 消息列表 - """ - langchain_messages = [] - if isinstance(messages, list): - for msg in messages: - if 'role' in msg.keys(): - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - langchain_messages.append(AIMessage(content=msg['content'])) - if "Query" in msg.keys(): - langchain_messages.append(HumanMessage(content=msg['Query'])) - langchain_messages.append(AIMessage(content=msg['Answer'])) - if isinstance(messages, dict): - if messages['type'] == 'human': - langchain_messages.append(HumanMessage(content=messages['content'])) - elif messages['type'] == 'ai': - langchain_messages.append(AIMessage(content=messages['content'])) - return langchain_messages + async def agent_chat_messages(user_content,ai_content): messages = [ diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 84ea9381..c0e6f86e 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,14 +1,19 @@ import asyncio +import json import sys import warnings from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph +from app.db import get_db, get_db_context from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.schemas.memory_agent_schema import AgentMemory_Long_Term +from app.services.memory_config_service import MemoryConfigService + warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) @@ -35,75 +40,67 @@ async def make_write_graph(): graph = workflow.compile() yield graph -async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): - """Dispatch long-term memory storage to Celery background tasks. - - Args: - long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate' - langchain_messages: List of messages to store - memory_config: Memory configuration ID (string) - end_user_id: End user identifier - scope: Window size for 'chunk' strategy (default: 6) - """ - from app.tasks import ( - long_term_storage_window_task, - # TODO: Uncomment when implemented - # long_term_storage_time_task, - # long_term_storage_aggregate_task, - ) - from app.core.logging_config import get_logger - - logger = get_logger(__name__) - - # Convert config to string if needed - config_id = str(memory_config) if memory_config else '' - - if long_term_type == 'chunk': - # Strategy 1: Window-based batching (6 rounds of dialogue) - logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}") - long_term_storage_window_task.delay( - end_user_id=end_user_id, - langchain_messages=langchain_messages, - config_id=config_id, - scope=scope - ) - # TODO: Uncomment when time-based strategy is fully implemented - # elif long_term_type == 'time': - # # Strategy 2: Time-based retrieval - # logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}") - # long_term_storage_time_task.delay( - # end_user_id=end_user_id, - # config_id=config_id, - # time_window=5 - # ) - # TODO: Uncomment when aggregate strategy is fully implemented - # elif long_term_type == 'aggregate': - # # Strategy 3: Aggregate judgment (deduplication) - # logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}") - # long_term_storage_aggregate_task.delay( - # end_user_id=end_user_id, - # langchain_messages=langchain_messages, - # config_id=config_id - # ) +async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): + from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment + from app.core.memory.agent.utils.redis_tool import write_store + write_store.save_session_write(end_user_id, (langchain_messages)) + # 获取数据库会话 + with get_db_context() as db_session: + try: + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=memory_config, # 改为整数 + service_name="MemoryAgentService" + ) + if long_term_type=='chunk': + '''方案一:对话窗口6轮对话''' + await window_dialogue(end_user_id,langchain_messages,memory_config,scope) + if long_term_type=='time': + """时间""" + await memory_long_term_storage(end_user_id, memory_config,5) + if long_term_type=='aggregate': + """方案三:聚合判断""" + await aggregate_judgment(end_user_id, langchain_messages, memory_config) + finally: + if db_session.in_transaction(): + db_session.rollback() + db_session.close() + + + +async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): + from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent + from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save + from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages + if storage_type == AgentMemory_Long_Term.STORAGE_RAG: + await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) + else: + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK + SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE + long_term_messages = await agent_chat_messages(message_chat, aimessages) + await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) + await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) # async def main(): # """主函数 - 运行工作流""" # langchain_messages = [ # { # "role": "user", -# "content": "今天周五好开心啊" +# "content": "今天周五去爬山" # }, # { # "role": "assistant", -# "content": "你也这么觉得,我也是耶" +# "content": "好耶" # } # # ] # end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID # memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" -# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) -# result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) +# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) +# # # # if __name__ == "__main__": diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index b6f50dd7..1a5017eb 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import Optional from pydantic import BaseModel @@ -14,4 +15,15 @@ class UserInput(BaseModel): class Write_UserInput(BaseModel): messages: list[dict] end_user_id: str - config_id: Optional[str] = None \ No newline at end of file + config_id: Optional[str] = None + +class AgentMemory_Long_Term(ABC): + """长期记忆配置常量""" + STORAGE_NEO4J = "neo4j" + STORAGE_RAG = "rag" + STRATEGY_AGGREGATE = "aggregate" + STRATEGY_CHUNK = "chunk" + STRATEGY_TIME = "time" + DEFAULT_SCOPE = 6 + + diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 9a3e1d37..43073555 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -110,6 +110,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str result = task_service.get_task_memory_read_result(task.id) status = result.get("status") logger.info(f"读取任务状态:{status}") + if memory_content: + memory_content = memory_content['answer'] finally: db.close() @@ -123,7 +125,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str "content_length": len(str(memory_content)) } ) - return f"检索到以下历史记忆:\n\n{memory_content}" except Exception as e: logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__}) From 41550d4a416dce4a8ecec8711117eedbd472aaa9 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 15:44:26 +0800 Subject: [PATCH 11/45] knowledge_retrieval/bug/fix --- api/app/core/rag/nlp/search.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index 1f696c98..774c7036 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -62,7 +62,15 @@ def knowledge_retrieval( merge_strategy = config.get("merge_strategy", "weight") reranker_id = config.get("reranker_id") reranker_top_k = config.get("reranker_top_k", 1024) - use_graph = config.get("use_graph", "false").lower() == "true" + # use_graph = config.get("use_graph", "false").lower() == "true" + + use_graph_value = config.get("use_graph", False) + if isinstance(use_graph_value, bool): + use_graph = use_graph_value + elif isinstance(use_graph_value, str): + use_graph = use_graph_value.lower() in ("true", "1", "yes") + else: + use_graph = False file_names_filter = [] if user_ids: @@ -159,13 +167,23 @@ def knowledge_retrieval( # Use the specified reranker for re-ranking if reranker_id: - return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) + try: + return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) + except Exception as rerank_error: + # If reranker fails, log warning and continue with original results + print(f"Failed to rerank documents: {str(rerank_error)}") + print(f"Continuing with original retrieval results (count: {len(all_results)})") + # use graph if use_graph: - from app.core.rag.common.settings import kg_retriever - doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) - if doc: - all_results.insert(0, doc) + try: + from app.core.rag.common.settings import kg_retriever + doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) + if doc: + all_results.insert(0, doc) + except Exception as graph_error: + print(f"Failed to retrieve from knowledge graph: {str(graph_error)}") + return all_results except Exception as e: From 514c19a247df91d2f72a0f82a7a60a026ec19031 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 15:51:13 +0800 Subject: [PATCH 12/45] knowledge_retrieval/bug/fix --- api/app/core/rag/nlp/search.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index 774c7036..572f2e3c 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD from app.core.rag.llm.chat_model import Base from app.core.rag.llm.embedding_model import OpenAIEmbed +import logging +logger = logging.getLogger(__name__) def knowledge_retrieval( query: str, @@ -171,10 +173,16 @@ def knowledge_retrieval( return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) except Exception as rerank_error: # If reranker fails, log warning and continue with original results - print(f"Failed to rerank documents: {str(rerank_error)}") - print(f"Continuing with original retrieval results (count: {len(all_results)})") - - # use graph + logger.warning( + "Reranker failed, falling back to original results", + extra={ + "reranker_id": reranker_id, + "query": query, + "doc_count": len(all_results), + "error": str(e), + }, + ) + if use_graph: try: from app.core.rag.common.settings import kg_retriever From 7922fc3b0e166959aafc8cc62f670cd56adda0aa Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 15:53:13 +0800 Subject: [PATCH 13/45] knowledge_retrieval/bug/fix --- api/app/core/rag/nlp/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index 572f2e3c..65fbd9cb 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -179,7 +179,7 @@ def knowledge_retrieval( "reranker_id": reranker_id, "query": query, "doc_count": len(all_results), - "error": str(e), + "error": str(rerank_error), }, ) From d0ddf288ca4a7acc6d92cb0d8abb45e1d5bd1c42 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 17:10:35 +0800 Subject: [PATCH 14/45] [fix]1.The "read_all_config" interface returns "scene_name";2.Memory configuration for lightweight query ontology scenarios --- api/app/controllers/ontology_controller.py | 41 +++++++++++++++++ .../repositories/memory_config_repository.py | 21 ++++++--- .../repositories/ontology_scene_repository.py | 45 +++++++++++++++++++ api/app/schemas/memory_storage_schema.py | 5 ++- api/app/services/memory_storage_service.py | 7 +-- 5 files changed, 107 insertions(+), 12 deletions(-) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 6520d835..3faa889b 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -766,6 +766,47 @@ async def delete_scene( return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e)) +@router.get("/scenes/simple", response_model=ApiResponse) +async def get_scenes_simple( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取场景简单列表(轻量级,用于下拉选择) + + 仅返回 scene_id 和 scene_name,不加载关联数据,响应速度快。 + 适用于前端下拉选择场景的场景。 + + Args: + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含场景简单列表 + + Examples: + GET /scenes/simple + 返回: {"items": [{"scene_id": "xxx", "scene_name": "场景1"}, ...]} + """ + api_logger.info(f"Simple scene list requested by user {current_user.id}") + + try: + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + from app.repositories.ontology_scene_repository import OntologySceneRepository + repo = OntologySceneRepository(db) + scenes = repo.get_simple_list(workspace_id) + + api_logger.info(f"Simple scene list retrieved: {len(scenes)} scenes") + return success(data={"items": scenes}, msg="查询成功") + + except Exception as e: + api_logger.error(f"Failed to get simple scene list: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + @router.get("/scenes", response_model=ApiResponse) async def get_scenes( workspace_id: Optional[str] = None, diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 22972669..e846e20c 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -279,6 +279,9 @@ class MemoryConfigRepository: if update.config_desc is not None: db_config.config_desc = update.config_desc has_update = True + if hasattr(update, 'scene_id') and update.scene_id is not None: + db_config.scene_id = update.scene_id + has_update = True if not has_update: raise ValueError("No fields to update") @@ -650,28 +653,32 @@ class MemoryConfigRepository: raise @staticmethod - def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]: - """获取所有配置参数 + def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]: + """获取所有配置参数,包含关联的场景名称 Args: db: 数据库会话 workspace_id: 工作空间ID,用于过滤查询结果 Returns: - List[MemoryConfig]: 配置列表 + List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称) """ + from app.models.ontology_scene import OntologyScene + db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") try: - query = db.query(MemoryConfig) + query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin( + OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id + ) if workspace_id: query = query.filter(MemoryConfig.workspace_id == workspace_id) - configs = query.order_by(desc(MemoryConfig.updated_at)).all() + results = query.order_by(desc(MemoryConfig.updated_at)).all() - db_logger.debug(f"配置列表查询成功: 数量={len(configs)}") - return configs + db_logger.debug(f"配置列表查询成功: 数量={len(results)}") + return results except Exception as e: db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}") diff --git a/api/app/repositories/ontology_scene_repository.py b/api/app/repositories/ontology_scene_repository.py index 322e111c..141b5d1c 100644 --- a/api/app/repositories/ontology_scene_repository.py +++ b/api/app/repositories/ontology_scene_repository.py @@ -392,3 +392,48 @@ class OntologySceneRepository: exc_info=True ) raise + + def get_simple_list(self, workspace_id: UUID) -> List[dict]: + """获取场景简单列表(仅包含scene_id和scene_name,用于下拉选择) + + 这是一个轻量级查询,不加载关联的classes,响应速度快。 + + Args: + workspace_id: 工作空间ID + + Returns: + List[dict]: 场景简单列表,每项包含scene_id和scene_name + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scenes = repo.get_simple_list(workspace_id) + >>> # [{"scene_id": "xxx", "scene_name": "场景1"}, ...] + """ + try: + logger.debug(f"Getting simple scene list for workspace: {workspace_id}") + + # 只查询需要的字段,不加载关联数据 + results = self.db.query( + OntologyScene.scene_id, + OntologyScene.scene_name + ).filter( + OntologyScene.workspace_id == workspace_id + ).order_by( + OntologyScene.updated_at.desc() + ).all() + + scenes = [ + {"scene_id": str(r.scene_id), "scene_name": r.scene_name} + for r in results + ] + + logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}") + + return scenes + + except Exception as e: + logger.error( + f"Failed to get simple scene list: {str(e)}", + exc_info=True + ) + raise diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 11cacda0..c3e7295b 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -248,8 +248,9 @@ class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 config_id: Union[uuid.UUID, int, str] = None - config_name: str = Field("配置名称", description="配置名称(字符串)") - config_desc: str = Field("配置描述", description="配置描述(字符串)") + config_name: Optional[str] = Field(None, description="配置名称(字符串)") + config_desc: Optional[str] = Field(None, description="配置描述(字符串)") + scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID") class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 741199c6..7ccd145c 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -183,11 +183,11 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # --- Read All --- def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 - configs = MemoryConfigRepository.get_all(self.db, workspace_id) + results = MemoryConfigRepository.get_all(self.db, workspace_id) # 将 ORM 对象转换为字典列表 data_list = [] - for config in configs: + for config, scene_name in results: # 安全地转换 user_id 为 int config_id_old = None if config.config_id_old: @@ -209,7 +209,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "end_user_id": config.end_user_id, "config_id_old": config_id_old, "apply_id": config.apply_id, - "scene_id": config.scene_id, + "scene_id": str(config.scene_id) if config.scene_id else None, + "scene_name": scene_name, # 新增:场景名称 "llm_id": config.llm_id, "embedding_id": config.embedding_id, "rerank_id": config.rerank_id, From 02714546713f422dfa9018756c4d8a118e695b32 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Feb 2026 17:21:04 +0800 Subject: [PATCH 15/45] fix(web): replace code editor --- web/package.json | 10 + web/src/components/CodeMirrorEditor/index.tsx | 150 +++++++++++++++ web/src/styles/index.css | 5 + .../Workflow/components/Editor/index.tsx | 12 +- .../plugin/JavaScriptHighlightPlugin.tsx | 182 ------------------ .../Editor/plugin/Python3HighlightPlugin.tsx | 177 ----------------- .../Properties/CodeExecution/index.tsx | 7 +- 7 files changed, 174 insertions(+), 369 deletions(-) create mode 100644 web/src/components/CodeMirrorEditor/index.tsx delete mode 100644 web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx delete mode 100644 web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx diff --git a/web/package.json b/web/package.json index e28e8b56..89800fcf 100644 --- a/web/package.json +++ b/web/package.json @@ -13,6 +13,14 @@ "@antv/layout": "^1.2.14-beta.8", "@antv/x6": "^3.0.1", "@antv/x6-react-shape": "^3.0.1", + "@codemirror/lang-cpp": "^6.0.3", + "@codemirror/lang-java": "^6.0.2", + "@codemirror/lang-javascript": "^6.2.4", + "@codemirror/lang-python": "^6.2.1", + "@codemirror/lang-rust": "^6.0.2", + "@codemirror/state": "^6.5.4", + "@codemirror/theme-one-dark": "^6.1.3", + "@codemirror/view": "^6.39.12", "@dnd-kit/core": "^6.3.1", "@dnd-kit/modifiers": "^9.0.0", "@dnd-kit/sortable": "^10.0.0", @@ -25,6 +33,7 @@ "antd": "^5.27.4", "axios": "^1.12.2", "clsx": "^2.1.1", + "codemirror": "^6.0.2", "copy-to-clipboard": "^3.3.3", "crypto-js": "^4.2.0", "dayjs": "^1.11.18", @@ -55,6 +64,7 @@ "@tailwindcss/postcss": "^4.1.14", "@tailwindcss/typography": "^0.5.19", "@tailwindcss/vite": "^4.1.14", + "@types/codemirror": "^5.60.17", "@types/crypto-js": "^4.2.2", "@types/js-yaml": "^4.0.9", "@types/node": "^24.6.0", diff --git a/web/src/components/CodeMirrorEditor/index.tsx b/web/src/components/CodeMirrorEditor/index.tsx new file mode 100644 index 00000000..e100b75b --- /dev/null +++ b/web/src/components/CodeMirrorEditor/index.tsx @@ -0,0 +1,150 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-04 17:20:52 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 17:20:52 + */ +import { useEffect, useRef, useMemo } from 'react'; +import { EditorView, basicSetup } from 'codemirror'; +import { EditorState } from '@codemirror/state'; +import { python } from '@codemirror/lang-python'; +import { javascript } from '@codemirror/lang-javascript'; +import { java } from '@codemirror/lang-java'; +import { cpp } from '@codemirror/lang-cpp'; +import { rust } from '@codemirror/lang-rust'; +import { oneDark } from '@codemirror/theme-one-dark'; + +/** + * Props for the CodeMirrorEditor component + * @property {string} value - The initial code content to display in the editor + * @property {string} language - Programming language for syntax highlighting (python, python3, javascript, typescript, java, cpp, c, rust) + * @property {function} onChange - Callback function triggered when editor content changes, receives the new code value + * @property {string} theme - Editor theme, either 'light' or 'dark' + * @property {boolean} readOnly - Whether the editor is read-only + * @property {string} height - Custom height for the editor + * @property {string} size - Predefined size preset: 'default' (120px min-height, 14px font) or 'small' (60px min-height, 12px font) + */ +interface CodeMirrorEditorProps { + value?: string; + language?: 'python' | 'python3' | 'javascript' | 'typescript' | 'java' | 'cpp' | 'c' | 'rust'; + onChange?: (value: string) => void; + theme?: 'light' | 'dark'; + readOnly?: boolean; + height?: string; + size?: 'default' | 'small'; +} + +/** + * Map of language identifiers to their corresponding CodeMirror language extensions + * Supports multiple programming languages with syntax highlighting + */ +const languageExtensions: Record = { + python: python(), + python3: python(), + javascript: javascript(), + typescript: javascript({ typescript: true }), + java: java(), + cpp: cpp(), + c: cpp(), + rust: rust(), +}; + +/** + * CodeMirrorEditor - A React wrapper component for CodeMirror 6 editor + * Provides a code editor with syntax highlighting, theme support, and customizable sizing + * Used in workflow code execution nodes for editing Python and JavaScript code + */ +const CodeMirrorEditor = ({ + value = '', + language = 'javascript', + onChange, + theme = 'light', + readOnly = false, + size, +}: CodeMirrorEditorProps) => { + // Reference to the DOM element that will contain the editor + const editorRef = useRef(null); + // Reference to the CodeMirror EditorView instance + const viewRef = useRef(null); + + /** + * Initialize CodeMirror editor when component mounts or when language/theme/readOnly changes + * Sets up extensions for syntax highlighting, change listeners, and theme + */ + useEffect(() => { + if (!editorRef.current) return; + + // Get the appropriate language extension, fallback to JavaScript if not found + const langExtension = languageExtensions[language] || languageExtensions.javascript; + + // Configure editor extensions + const extensions = [ + basicSetup, // Basic editor features (line numbers, bracket matching, etc.) + langExtension, // Language-specific syntax highlighting + // Listen for document changes and trigger onChange callback + EditorView.updateListener.of((update) => { + if (update.docChanged && onChange) { + onChange(update.state.doc.toString()); + } + }), + EditorState.readOnly.of(readOnly), // Set read-only mode + ]; + + // Apply dark theme if specified + if (theme === 'dark') { + extensions.push(oneDark); + } + + // Create editor state with initial value and extensions + const state = EditorState.create({ + doc: value, + extensions, + }); + + // Create and mount the editor view + viewRef.current = new EditorView({ + state, + parent: editorRef.current, + }); + + // Cleanup: destroy editor instance when component unmounts or dependencies change + return () => { + viewRef.current?.destroy(); + }; + }, [language, theme, readOnly]); + + /** + * Update editor content when the value prop changes externally + * Only updates if the new value differs from current editor content + */ + useEffect(() => { + if (viewRef.current && value !== viewRef.current.state.doc.toString()) { + viewRef.current.dispatch({ + changes: { + from: 0, + to: viewRef.current.state.doc.length, + insert: value, + }, + }); + } + }, [value]); + + // Calculate minimum height based on size prop: small (60px) or default (120px) + const minHeight = useMemo(() => { + return `${size === 'small' ? 60 : 120}px` + }, [size]) + + // Calculate font size based on size prop: small (12px) or default (14px) + const fontSize = useMemo(() => { + return `${size === 'small' ? 12 : 14}px` + }, [size]) + + // Calculate line height based on size prop: small (16px) or default (20px) + const lineHeight = useMemo(() => { + return `${size === 'small' ? 16 : 20}px` + }, [size]) + + return
; +}; + +export default CodeMirrorEditor; diff --git a/web/src/styles/index.css b/web/src/styles/index.css index bbbe9cd9..d937396a 100644 --- a/web/src/styles/index.css +++ b/web/src/styles/index.css @@ -180,4 +180,9 @@ body { .x6-node foreignObject > body { min-height: 100%; max-height: 100%; +} + +.ͼ2 .cm-gutters { + background-color: #FFFFFF; + border: none; } \ No newline at end of file diff --git a/web/src/views/Workflow/components/Editor/index.tsx b/web/src/views/Workflow/components/Editor/index.tsx index 4c8540a8..60da03a7 100644 --- a/web/src/views/Workflow/components/Editor/index.tsx +++ b/web/src/views/Workflow/components/Editor/index.tsx @@ -15,8 +15,6 @@ import CharacterCountPlugin from './plugin/CharacterCountPlugin' import InitialValuePlugin from './plugin/InitialValuePlugin'; import CommandPlugin from './plugin/CommandPlugin'; import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin'; -import Python3HighlightPlugin from './plugin/Python3HighlightPlugin'; -import JavaScriptHighlightPlugin from './plugin/JavaScriptHighlightPlugin'; import LineNumberPlugin from './plugin/LineNumberPlugin'; import BlurPlugin from './plugin/BlurPlugin'; import { VariableNode } from './nodes/VariableNode' @@ -32,7 +30,7 @@ export interface LexicalEditorProps { lineHeight?: number; size?: 'default' | 'small'; type?: 'input' | 'textarea', - language?: 'string' | 'jinja2' | 'python3' | 'javascript' + language?: 'string' | 'jinja2' } const theme = { @@ -67,7 +65,7 @@ const Editor: FC =({ const [enableLineNumbers, setEnableLineNumbers] = useState(false) useEffect(() => { - const needsLineNumbers = language === 'jinja2' || language === 'python3' || language === 'javascript'; + const needsLineNumbers = language === 'jinja2'; setEnableJinja2(language === 'jinja2'); setEnableLineNumbers(needsLineNumbers); @@ -237,13 +235,11 @@ const Editor: FC =({ {language === 'jinja2' && } - {language === 'python3' && } - {language === 'javascript' && } {enableLineNumbers && } { setCount(count) }} onChange={onChange} /> - - {enableLineNumbers && } + + {enableJinja2 && }
); diff --git a/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx deleted file mode 100644 index 21219139..00000000 --- a/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx +++ /dev/null @@ -1,182 +0,0 @@ -import { useEffect, useRef } from 'react'; -import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical'; - -const JS_KEYWORDS = new Set([ - 'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default', - 'delete', 'do', 'else', 'export', 'extends', 'finally', 'for', 'function', 'if', 'import', - 'in', 'instanceof', 'let', 'new', 'return', 'super', 'switch', 'this', 'throw', 'try', - 'typeof', 'var', 'void', 'while', 'with', 'yield', 'true', 'false', 'null', 'undefined' -]); - -const JavaScriptHighlightPlugin = () => { - const [editor] = useLexicalComposerContext(); - const isPastingRef = useRef(false); - - useEffect(() => { - return editor.registerCommand( - PASTE_COMMAND, - () => { - isPastingRef.current = true; - setTimeout(() => { - isPastingRef.current = false; - }, 100); - return false; - }, - COMMAND_PRIORITY_LOW - ); - }, [editor]); - - useEffect(() => { - return editor.registerNodeTransform(TextNode, (textNode: TextNode) => { - if (isPastingRef.current) return; - - const text = textNode.getTextContent(); - - if (textNode.hasFormat('code')) return; - if (!needsHighlight(text)) return; - if (textNode.getStyle()) return; - - const parent = textNode.getParent(); - if (!parent) return; - - const selection = $getSelection(); - let selectionOffset = null; - if ($isRangeSelection(selection)) { - const anchor = selection.anchor; - if (anchor.getNode() === textNode) { - selectionOffset = anchor.offset; - } - } - - const tokens = tokenizeJavaScript(text); - if (tokens.length <= 1) return; - - const newNodes = tokens.map(token => { - const newNode = $createTextNode(token.text); - newNode.toggleFormat('code'); - - switch (token.type) { - case 'keyword': - newNode.setStyle('color: #d73a49; font-weight: 600;'); - break; - case 'string': - newNode.setStyle('color: #032f62;'); - break; - case 'comment': - newNode.setStyle('color: #6a737d; font-style: italic;'); - break; - case 'number': - newNode.setStyle('color: #005cc5; font-weight: 500;'); - break; - case 'function': - newNode.setStyle('color: #6f42c1; font-weight: 500;'); - break; - } - - return newNode; - }); - - if (newNodes.length > 1) { - textNode.replace(newNodes[0]); - for (let i = 1; i < newNodes.length; i++) { - newNodes[i - 1].insertAfter(newNodes[i]); - } - - if (selectionOffset !== null && $isRangeSelection(selection)) { - let currentOffset = 0; - for (const node of newNodes) { - const nodeLength = node.getTextContent().length; - if (currentOffset + nodeLength >= selectionOffset) { - node.select(selectionOffset - currentOffset, selectionOffset - currentOffset); - break; - } - currentOffset += nodeLength; - } - } - } - }); - }, [editor]); - - return null; -}; - -function needsHighlight(text: string): boolean { - return /[a-zA-Z0-9_/"'`]/.test(text); -} - -function tokenizeJavaScript(text: string): Array<{text: string, type: string}> { - const tokens: Array<{text: string, type: string}> = []; - let i = 0; - - while (i < text.length) { - // Single-line comments - if (text.slice(i, i + 2) === '//') { - let start = i; - while (i < text.length && text[i] !== '\n') i++; - tokens.push({ text: text.slice(start, i), type: 'comment' }); - continue; - } - - // Multi-line comments - if (text.slice(i, i + 2) === '/*') { - let start = i; - i += 2; - while (i < text.length && text.slice(i, i + 2) !== '*/') i++; - if (i < text.length) i += 2; - tokens.push({ text: text.slice(start, i), type: 'comment' }); - continue; - } - - // Strings - if (text[i] === '"' || text[i] === "'" || text[i] === '`') { - const quote = text[i]; - let start = i++; - - while (i < text.length) { - if (text[i] === quote && text[i - 1] !== '\\') { - i++; - break; - } - i++; - } - tokens.push({ text: text.slice(start, i), type: 'string' }); - continue; - } - - // Numbers - if (/\d/.test(text[i])) { - let start = i; - while (i < text.length && /[\d.]/.test(text[i])) i++; - tokens.push({ text: text.slice(start, i), type: 'number' }); - continue; - } - - // Keywords and identifiers - if (/[a-zA-Z_$]/.test(text[i])) { - let start = i; - while (i < text.length && /[a-zA-Z0-9_$]/.test(text[i])) i++; - const word = text.slice(start, i); - - if (JS_KEYWORDS.has(word)) { - tokens.push({ text: word, type: 'keyword' }); - } else if (i < text.length && text[i] === '(') { - tokens.push({ text: word, type: 'function' }); - } else { - tokens.push({ text: word, type: 'text' }); - } - continue; - } - - // Other characters - let start = i; - while (i < text.length && !/[a-zA-Z0-9_$/"'`]/.test(text[i])) i++; - if (start < i) { - tokens.push({ text: text.slice(start, i), type: 'text' }); - } - } - - return tokens; -} - -export default JavaScriptHighlightPlugin; diff --git a/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx deleted file mode 100644 index 12830ffb..00000000 --- a/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx +++ /dev/null @@ -1,177 +0,0 @@ -import { useEffect, useRef } from 'react'; -import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical'; - -const PYTHON_KEYWORDS = new Set([ - 'False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue', - 'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', - 'in', 'is', 'lambda', 'nonlocal', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while', - 'with', 'yield' -]); - -const Python3HighlightPlugin = () => { - const [editor] = useLexicalComposerContext(); - const isPastingRef = useRef(false); - - useEffect(() => { - return editor.registerCommand( - PASTE_COMMAND, - () => { - isPastingRef.current = true; - setTimeout(() => { - isPastingRef.current = false; - }, 100); - return false; - }, - COMMAND_PRIORITY_LOW - ); - }, [editor]); - - useEffect(() => { - return editor.registerNodeTransform(TextNode, (textNode: TextNode) => { - if (isPastingRef.current) return; - - const text = textNode.getTextContent(); - - if (textNode.hasFormat('code')) return; - if (textNode.getStyle()) return; - if (!needsHighlight(text)) return; - - const parent = textNode.getParent(); - if (!parent) return; - - const selection = $getSelection(); - let selectionOffset = null; - if ($isRangeSelection(selection)) { - const anchor = selection.anchor; - if (anchor.getNode() === textNode) { - selectionOffset = anchor.offset; - } - } - - const tokens = tokenizePython(text); - if (tokens.length <= 1) return; - - const newNodes = tokens.map(token => { - const newNode = $createTextNode(token.text); - newNode.toggleFormat('code'); - - switch (token.type) { - case 'keyword': - newNode.setStyle('color: #d73a49; font-weight: 600;'); - break; - case 'string': - newNode.setStyle('color: #032f62;'); - break; - case 'comment': - newNode.setStyle('color: #6a737d; font-style: italic;'); - break; - case 'number': - newNode.setStyle('color: #005cc5; font-weight: 500;'); - break; - case 'function': - newNode.setStyle('color: #6f42c1; font-weight: 500;'); - break; - } - - return newNode; - }); - - if (newNodes.length > 1) { - textNode.replace(newNodes[0]); - for (let i = 1; i < newNodes.length; i++) { - newNodes[i - 1].insertAfter(newNodes[i]); - } - - if (selectionOffset !== null && $isRangeSelection(selection)) { - let currentOffset = 0; - for (const node of newNodes) { - const nodeLength = node.getTextContent().length; - if (currentOffset + nodeLength >= selectionOffset) { - node.select(selectionOffset - currentOffset, selectionOffset - currentOffset); - break; - } - currentOffset += nodeLength; - } - } - } - }); - }, [editor]); - - return null; -}; - -function needsHighlight(text: string): boolean { - return /[a-zA-Z0-9_#"']/.test(text); -} - -function tokenizePython(text: string): Array<{text: string, type: string}> { - const tokens: Array<{text: string, type: string}> = []; - let i = 0; - - while (i < text.length) { - // Comments - if (text[i] === '#') { - let start = i; - while (i < text.length && text[i] !== '\n') i++; - tokens.push({ text: text.slice(start, i), type: 'comment' }); - continue; - } - - // Strings - if (text[i] === '"' || text[i] === "'") { - const quote = text[i]; - let start = i++; - const isTriple = text.slice(start, start + 3) === quote.repeat(3); - if (isTriple) i += 2; - - while (i < text.length) { - if (isTriple && text.slice(i, i + 3) === quote.repeat(3)) { - i += 3; - break; - } else if (!isTriple && text[i] === quote && text[i - 1] !== '\\') { - i++; - break; - } - i++; - } - tokens.push({ text: text.slice(start, i), type: 'string' }); - continue; - } - - // Numbers - if (/\d/.test(text[i])) { - let start = i; - while (i < text.length && /[\d.]/.test(text[i])) i++; - tokens.push({ text: text.slice(start, i), type: 'number' }); - continue; - } - - // Keywords and identifiers - if (/[a-zA-Z_]/.test(text[i])) { - let start = i; - while (i < text.length && /[a-zA-Z0-9_]/.test(text[i])) i++; - const word = text.slice(start, i); - - if (PYTHON_KEYWORDS.has(word)) { - tokens.push({ text: word, type: 'keyword' }); - } else if (i < text.length && text[i] === '(') { - tokens.push({ text: word, type: 'function' }); - } else { - tokens.push({ text: word, type: 'text' }); - } - continue; - } - - // Other characters - let start = i; - while (i < text.length && !/[a-zA-Z0-9_#"']/.test(text[i])) i++; - if (start < i) { - tokens.push({ text: text.slice(start, i), type: 'text' }); - } - } - - return tokens; -} - -export default Python3HighlightPlugin; diff --git a/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx b/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx index 8a0ea03e..b9c2c881 100644 --- a/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx +++ b/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx @@ -5,8 +5,8 @@ import { Node } from '@antv/x6' import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' import MappingList from '../MappingList' -import Editor from '../../Editor' import OutputList from './OutputList' +import CodeMirrorEditor from '@/components/CodeMirrorEditor'; interface MappingItem { name?: string @@ -110,7 +110,10 @@ const CodeExecution: FC = ({ options }) => { prev.language !== curr.language}> {() => ( - + )} From aad8f0e36b98aa36006a8bacf38bd685f438a1fa Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 17:23:52 +0800 Subject: [PATCH 16/45] [changes]Modify the description of the time for the recent event --- api/app/services/memory_storage_service.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 7ccd145c..d3d267be 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -636,10 +636,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]: if m < 1: latest_relative = "刚刚" elif m < 60: - latest_relative = f"{m}分钟前" + latest_relative = "一会前" else: - h = int(m // 60) - latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前" + latest_relative = "较早前" except Exception: pass From 24fbdbd7163d7a9d6b352a4abd824112313928c0 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 17:40:19 +0800 Subject: [PATCH 17/45] [changes]Modify the code based on the AI review --- api/app/controllers/memory_storage_controller.py | 5 +++++ api/app/controllers/ontology_controller.py | 6 +++--- api/app/repositories/memory_config_repository.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index ae372d3b..0b627775 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -195,6 +195,11 @@ def update_config( api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + # 校验至少有一个字段需要更新 + if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") + return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") try: svc = DataConfigService(db) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 3faa889b..4e244e35 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -52,6 +52,7 @@ from app.services.ontology_service import OntologyService from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.utils.validation.owl_validator import OWLValidator from app.services.model_service import ModelConfigService +from app.repositories.ontology_scene_repository import OntologySceneRepository api_logger = get_api_logger() @@ -785,7 +786,7 @@ async def get_scenes_simple( Examples: GET /scenes/simple - 返回: {"items": [{"scene_id": "xxx", "scene_name": "场景1"}, ...]} + 返回: {"data": [{"scene_id": "xxx", "scene_name": "场景1"}, ...]} """ api_logger.info(f"Simple scene list requested by user {current_user.id}") @@ -795,12 +796,11 @@ async def get_scenes_simple( api_logger.warning(f"User {current_user.id} has no current workspace") return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") - from app.repositories.ontology_scene_repository import OntologySceneRepository repo = OntologySceneRepository(db) scenes = repo.get_simple_list(workspace_id) api_logger.info(f"Simple scene list retrieved: {len(scenes)} scenes") - return success(data={"items": scenes}, msg="查询成功") + return success(data=scenes, msg="查询成功") except Exception as e: api_logger.error(f"Failed to get simple scene list: {str(e)}", exc_info=True) diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index e846e20c..acb68ba0 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -279,7 +279,7 @@ class MemoryConfigRepository: if update.config_desc is not None: db_config.config_desc = update.config_desc has_update = True - if hasattr(update, 'scene_id') and update.scene_id is not None: + if update.scene_id is not None: db_config.scene_id = update.scene_id has_update = True From 8c7a1348cf6dad7d264f8085267b179494b21324 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Feb 2026 17:41:53 +0800 Subject: [PATCH 18/45] feat(web): update memory config ontology api --- web/src/api/ontology.ts | 1 + web/src/views/MemoryManagement/components/MemoryForm.tsx | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/web/src/api/ontology.ts b/web/src/api/ontology.ts index 4213d362..bb5244e4 100644 --- a/web/src/api/ontology.ts +++ b/web/src/api/ontology.ts @@ -2,6 +2,7 @@ import { request } from '@/utils/request' import type { Query, OntologyModalData, OntologyClassModalData, OntologyClassExtractModalData } from '@/views/Ontology/types' // Scene list +export const getOntologyScenesSimpleUrl = '/memory/ontology/scenes/simple' export const getOntologyScenesUrl = '/memory/ontology/scenes' export const getOntologyScenesList = (data: Query) => { return request.get(getOntologyScenesUrl, data) diff --git a/web/src/views/MemoryManagement/components/MemoryForm.tsx b/web/src/views/MemoryManagement/components/MemoryForm.tsx index 84b0d9c2..0b5b08f7 100644 --- a/web/src/views/MemoryManagement/components/MemoryForm.tsx +++ b/web/src/views/MemoryManagement/components/MemoryForm.tsx @@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next'; import type { MemoryFormData, Memory, MemoryFormRef } from '../types'; import RbModal from '@/components/RbModal' import { createMemoryConfig, updateMemoryConfig } from '@/api/memory' -import { getOntologyScenesUrl } from '@/api/ontology' +import { getOntologyScenesSimpleUrl } from '@/api/ontology' import CustomSelect from '@/components/CustomSelect'; const FormItem = Form.Item; @@ -114,8 +114,7 @@ const MemoryForm = forwardRef(({ > Date: Wed, 4 Feb 2026 17:47:12 +0800 Subject: [PATCH 19/45] fix(web): ui update --- web/src/views/MemoryManagement/index.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/views/MemoryManagement/index.tsx b/web/src/views/MemoryManagement/index.tsx index a81e53d6..5653317a 100644 --- a/web/src/views/MemoryManagement/index.tsx +++ b/web/src/views/MemoryManagement/index.tsx @@ -97,7 +97,7 @@ const MemoryManagement: React.FC = () => { title={item.config_name} > -
{item.config_desc}
+
{item.config_desc}
From 5ee54f4e0e0337fbfa1c476596bf7057b3104329 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 17:57:43 +0800 Subject: [PATCH 20/45] knowledge_retrieval/bug/fix --- .../core/memory/agent/langgraph_graph/routing/write_router.py | 2 -- api/app/core/memory/agent/langgraph_graph/write_graph.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 29257e88..8935fff2 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -115,8 +115,6 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type, logger.info(f'写入短长期:') # yield db_session finally: - if db_session.in_transaction(): - db_session.rollback() db_session.close() diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index c0e6f86e..6f995ab1 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -63,8 +63,6 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ """方案三:聚合判断""" await aggregate_judgment(end_user_id, langchain_messages, memory_config) finally: - if db_session.in_transaction(): - db_session.rollback() db_session.close() From 918e7285c4f8af7ede474f6885e8db236664863b Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 18:01:05 +0800 Subject: [PATCH 21/45] knowledge_retrieval/bug/fix --- .../core/memory/agent/langgraph_graph/routing/write_router.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 8935fff2..863fa590 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -18,7 +18,6 @@ from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task from app.utils.config_utils import resolve_config_id - logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') From 34276e2066f6abf98520854d870a2c0aab020f38 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 18:06:56 +0800 Subject: [PATCH 22/45] knowledge_retrieval/bug/fix --- .../langgraph_graph/routing/write_router.py | 37 +++++++++---------- .../agent/langgraph_graph/write_graph.py | 31 +++++++--------- 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 863fa590..6266d6d2 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -94,27 +94,24 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): with get_db_context() as db_session: - try: - repo = LongTermMemoryRepository(db_session) - await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + repo = LongTermMemoryRepository(db_session) + await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + + from app.core.memory.agent.utils.redis_tool import write_store + result = write_store.get_session_by_userid(end_user_id) + if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'---------写入短长期-----------') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') - from app.core.memory.agent.utils.redis_tool import write_store - result = write_store.get_session_by_userid(end_user_id) - if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: - data = await format_parsing(result, "dict") - chunk_data = data[:scope] - if len(chunk_data)==scope: - repo.upsert(end_user_id, chunk_data) - logger.info(f'---------写入短长期-----------') - else: - long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) - long_messages = await messages_parse(long_time_data) - repo.upsert(end_user_id, long_messages) - logger.info(f'写入短长期:') - # yield db_session - finally: - db_session.close() '''根据窗口''' diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 6f995ab1..fd2c498c 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -47,23 +47,20 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ write_store.save_session_write(end_user_id, (langchain_messages)) # 获取数据库会话 with get_db_context() as db_session: - try: - config_service = MemoryConfigService(db_session) - memory_config = config_service.load_memory_config( - config_id=memory_config, # 改为整数 - service_name="MemoryAgentService" - ) - if long_term_type=='chunk': - '''方案一:对话窗口6轮对话''' - await window_dialogue(end_user_id,langchain_messages,memory_config,scope) - if long_term_type=='time': - """时间""" - await memory_long_term_storage(end_user_id, memory_config,5) - if long_term_type=='aggregate': - """方案三:聚合判断""" - await aggregate_judgment(end_user_id, langchain_messages, memory_config) - finally: - db_session.close() + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=memory_config, # 改为整数 + service_name="MemoryAgentService" + ) + if long_term_type=='chunk': + '''方案一:对话窗口6轮对话''' + await window_dialogue(end_user_id,langchain_messages,memory_config,scope) + if long_term_type=='time': + """时间""" + await memory_long_term_storage(end_user_id, memory_config,5) + if long_term_type=='aggregate': + """方案三:聚合判断""" + await aggregate_judgment(end_user_id, langchain_messages, memory_config) From 1c8a83140bfda937985bf1597f6a8b692c0c088e Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 4 Feb 2026 18:08:02 +0800 Subject: [PATCH 23/45] feat(workflow): add token usage statistics for question classifier and parameter extraction --- .../core/workflow/nodes/parameter_extractor/node.py | 13 +++++++++++++ .../core/workflow/nodes/question_classifier/node.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index ec58d96c..079cd4cc 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -23,6 +23,18 @@ class ParameterExtractorNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config: ParameterExtractorNodeConfig | None = None + self.response_metadata = {} + + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: + if self.response_metadata: + usage = self.response_metadata.get('token_usage') + if usage: + return { + "prompt_tokens": usage.get('prompt_tokens', 0), + "completion_tokens": usage.get('completion_tokens', 0), + "total_tokens": usage.get('total_tokens', 0) + } + return None @staticmethod def _get_prompt(): @@ -171,6 +183,7 @@ class ParameterExtractorNode(BaseNode): ]) model_resp = await llm.ainvoke(messages) + self.response_metadata = model_resp.response_metadata result = json_repair.repair_json(model_resp.content, return_objects=True) logger.info(f"node: {self.node_id} get params:{result}") diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 6df410cb..8076dc9d 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -23,6 +23,18 @@ class QuestionClassifierNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: QuestionClassifierNodeConfig | None = None self.category_to_case_map = {} + self.response_metadata = {} + + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: + if self.response_metadata: + usage = self.response_metadata.get('token_usage') + if usage: + return { + "prompt_tokens": usage.get('prompt_tokens', 0), + "completion_tokens": usage.get('completion_tokens', 0), + "total_tokens": usage.get('total_tokens', 0) + } + return None def _get_llm_instance(self) -> RedBearLLM: """获取LLM实例""" @@ -112,6 +124,7 @@ class QuestionClassifierNode(BaseNode): response = await llm.ainvoke(messages) result = response.content.strip() + self.response_metadata = response.response_metadata if result in category_names: category = result From 9e6e8f50f8136fb8c963af34d9446dc49a237cad Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Feb 2026 18:36:45 +0800 Subject: [PATCH 24/45] feat(web): move prompt menu --- web/src/routes/index.tsx | 1 - web/src/routes/routes.json | 2 +- web/src/store/menu.json | 30 +++++++++++++++--------------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/web/src/routes/index.tsx b/web/src/routes/index.tsx index 21eaeab8..74cf89ec 100644 --- a/web/src/routes/index.tsx +++ b/web/src/routes/index.tsx @@ -3,7 +3,6 @@ import { createHashRouter, createRoutesFromElements, Route } from 'react-router- // 导入路由配置JSON import routesConfig from './routes.json'; -import Ontology from '@/views/Ontology'; // 递归函数,用于生成路由元素 diff --git a/web/src/routes/routes.json b/web/src/routes/routes.json index b02ebddf..aa5a8178 100644 --- a/web/src/routes/routes.json +++ b/web/src/routes/routes.json @@ -7,6 +7,7 @@ { "path": "/model", "element": "ModelManagement" }, { "path": "/space", "element": "SpaceManagement" }, { "path": "/tool", "element": "ToolManagement" }, + { "path": "/prompt", "element": "Prompt" }, { "path": "/pricing", "element": "Pricing" }, { "path": "/order-pay", "element": "OrderPayment" }, { "path": "/orders", "element": "OrderHistory" }, @@ -35,7 +36,6 @@ { "path": "/reflection-engine/:id", "element": "SelfReflectionEngine" }, { "path": "/space-config", "element": "SpaceConfig" }, { "path": "/ontology", "element": "Ontology" }, - { "path": "/prompt", "element": "Prompt" }, { "path": "/no-permission", "element": "NoPermission" }, { "path": "/*", "element": "NotFound" } ] diff --git a/web/src/store/menu.json b/web/src/store/menu.json index d264e061..45da151e 100644 --- a/web/src/store/menu.json +++ b/web/src/store/menu.json @@ -52,6 +52,21 @@ "sort": 0, "subs": [] }, + { + "id": 20, + "parent": 0, + "code": "prompt", + "label": "提示词", + "i18nKey": "menu.prompt", + "path": "/prompt", + "enable": true, + "display": true, + "level": 1, + "sort": 0, + "icon": null, + "iconActive": null, + "subs": null + }, { "id": 6, "parent": 0, @@ -377,21 +392,6 @@ "iconActive": null, "subs": null }, - { - "id": 20, - "parent": 0, - "code": "prompt", - "label": "提示词", - "i18nKey": "menu.prompt", - "path": "/prompt", - "enable": true, - "display": true, - "level": 1, - "sort": 0, - "icon": null, - "iconActive": null, - "subs": null - }, { "id": 19, "parent": 0, From 7c1f62279754c5611c535bcf1c1630fde7da678f Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 20:11:05 +0800 Subject: [PATCH 25/45] Multiple independent transactions - single transaction --- api/app/core/memory/agent/langgraph_graph/tools/write_tool.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py index d0be8e5c..9ce581ee 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -1,8 +1,6 @@ import json from langchain_core.messages import HumanMessage, AIMessage - - async def format_parsing(messages: list,type:str='string'): """ 格式化解析消息列表 From 3f906d81cbf9eafbbe14d59ea3167b96bbdc9661 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 20:19:04 +0800 Subject: [PATCH 26/45] Multiple independent transactions - single transaction --- api/app/repositories/neo4j/graph_saver.py | 144 +++++++++++++++++----- 1 file changed, 112 insertions(+), 32 deletions(-) diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 1575315f..f8aa7cdb 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -147,14 +147,14 @@ async def save_statement_entity_edges( async def save_dialog_and_statements_to_neo4j( - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - entity_edges: List[EntityEntityEdge], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + entity_edges: List[EntityEntityEdge], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + connector: Neo4jConnector ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. @@ -171,40 +171,120 @@ async def save_dialog_and_statements_to_neo4j( Returns: bool: True if successful, False otherwise """ - try: - # Save all dialogue nodes in batch - dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector) - if dialogue_uuids: + + # 定义事务函数,将所有写操作放在一个事务中 + async def _save_all_in_transaction(tx): + """在单个事务中执行所有保存操作,避免死锁""" + results = {} + + # 1. Save all dialogue nodes in batch + if dialogue_nodes: + from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE + dialogue_data = [node.model_dump() for node in dialogue_nodes] + result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data) + dialogue_uuids = [record["uuid"] async for record in result] + results['dialogues'] = dialogue_uuids print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") - else: - print("Failed to save dialogues to Neo4j") - return False - # Save all chunk nodes in batch - await save_chunk_nodes(chunk_nodes, connector) + # 2. Save all chunk nodes in batch + if chunk_nodes: + from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE + chunk_data = [node.model_dump() for node in chunk_nodes] + result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data) + chunk_uuids = [record["uuid"] async for record in result] + results['chunks'] = chunk_uuids + print(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") - # Save all statement nodes in batch + # 3. Save all statement nodes in batch if statement_nodes: - statement_uuids = await add_statement_nodes(statement_nodes, connector) - if statement_uuids: - print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") - else: - print("Failed to save statement nodes to Neo4j") - return False - else: - print("No statement nodes to save") + from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE + statement_data = [node.model_dump() for node in statement_nodes] + result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data) + statement_uuids = [record["uuid"] async for record in result] + results['statements'] = statement_uuids + print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") - # Save entities and relationships - await save_entities_and_relationships(entity_nodes, entity_edges, connector) - print("Successfully saved entities and relationships to Neo4j") + # 4. Save entities + if entity_nodes: + from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE + entity_data = [entity.model_dump() for entity in entity_nodes] + result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data) + entity_uuids = [record["uuid"] async for record in result] + results['entities'] = entity_uuids + print(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j") - # Save new edges - await save_statement_chunk_edges(statement_chunk_edges, connector) - await save_statement_entity_edges(statement_entity_edges, connector) + # 5. Create entity relationships + if entity_edges: + from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE + relationship_data = [] + for edge in entity_edges: + relationship_data.append({ + 'source_id': edge.source, + 'target_id': edge.target, + 'predicate': edge.relation_type, + 'statement_id': edge.source_statement_id, + 'value': edge.relation_value, + 'statement': edge.statement, + 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, + 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, + 'created_at': edge.created_at.isoformat(), + 'expired_at': edge.expired_at.isoformat(), + 'run_id': edge.run_id, + 'end_user_id': edge.end_user_id, + }) + result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data) + rel_uuids = [record["uuid"] async for record in result] + results['entity_relationships'] = rel_uuids + print(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j") + # 6. Save statement-chunk edges + if statement_chunk_edges: + from app.repositories.neo4j.cypher_queries import STATEMENT_CHUNK_EDGE_SAVE + sc_edge_data = [] + for edge in statement_chunk_edges: + sc_edge_data.append({ + "id": edge.id, + "source": edge.source, + "target": edge.target, + "created_at": edge.created_at.isoformat(), + "expired_at": edge.expired_at.isoformat(), + "run_id": edge.run_id, + "end_user_id": edge.end_user_id, + }) + result = await tx.run(STATEMENT_CHUNK_EDGE_SAVE, edges=sc_edge_data) + sc_uuids = [record["uuid"] async for record in result] + results['statement_chunk_edges'] = sc_uuids + print(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j") + + # 7. Save statement-entity edges + if statement_entity_edges: + from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE + se_edge_data = [] + for edge in statement_entity_edges: + se_edge_data.append({ + "id": edge.id, + "source": edge.source, + "target": edge.target, + "created_at": edge.created_at.isoformat(), + "expired_at": edge.expired_at.isoformat(), + "run_id": edge.run_id, + "end_user_id": edge.end_user_id, + }) + result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, edges=se_edge_data) + se_uuids = [record["uuid"] async for record in result] + results['statement_entity_edges'] = se_uuids + print(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") + + return results + + try: + # 使用显式写事务执行所有操作,避免死锁 + results = await connector.execute_write_transaction(_save_all_in_transaction) + print("Successfully saved all data to Neo4j in a single transaction") return True except Exception as e: print(f"Neo4j integration error: {e}") print("Continuing without database storage...") return False + From 3735bdde194ff44ba7b625118098f07e83fc3737 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 20:20:45 +0800 Subject: [PATCH 27/45] Multiple independent transactions - single transaction --- .../core/memory/agent/utils/write_tools.py | 56 ++++++++++++++----- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 446ab86a..aa66014c 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline This module provides the main write function for executing the knowledge extraction pipeline. Only MemoryConfig is needed - clients are constructed internally. """ +import asyncio import time from datetime import datetime @@ -123,23 +124,48 @@ async def write( except Exception as e: logger.error(f"Error creating indexes: {e}", exc_info=True) + # 添加死锁重试机制 + max_retries = 3 + retry_delay = 1 # 秒 + + for attempt in range(max_retries): + try: + success = await save_dialog_and_statements_to_neo4j( + dialogue_nodes=all_dialogue_nodes, + chunk_nodes=all_chunk_nodes, + statement_nodes=all_statement_nodes, + entity_nodes=all_entity_nodes, + statement_chunk_edges=all_statement_chunk_edges, + statement_entity_edges=all_statement_entity_edges, + entity_edges=all_entity_entity_edges, + connector=neo4j_connector + ) + if success: + logger.info("Successfully saved all data to Neo4j") + break + else: + logger.warning("Failed to save some data to Neo4j") + if attempt < max_retries - 1: + logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})") + await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 + except Exception as e: + error_msg = str(e) + # 检查是否是死锁错误 + if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower(): + if attempt < max_retries - 1: + logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})") + await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 + else: + logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}") + raise + else: + # 非死锁错误,直接抛出 + raise + try: - success = await save_dialog_and_statements_to_neo4j( - dialogue_nodes=all_dialogue_nodes, - chunk_nodes=all_chunk_nodes, - statement_nodes=all_statement_nodes, - entity_nodes=all_entity_nodes, - statement_chunk_edges=all_statement_chunk_edges, - statement_entity_edges=all_statement_entity_edges, - entity_edges=all_entity_entity_edges, - connector=neo4j_connector - ) - if success: - logger.info("Successfully saved all data to Neo4j") - else: - logger.warning("Failed to save some data to Neo4j") - finally: await neo4j_connector.close() + except Exception as e: + logger.error(f"Error closing Neo4j connector: {e}") log_time("Neo4j Database Save", time.time() - step_start, log_file) From 657d48a5f9b6321df8d67fa85ed5d3d89ccb8c9b Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 4 Feb 2026 20:25:45 +0800 Subject: [PATCH 28/45] Multiple independent transactions - single transaction --- api/app/repositories/neo4j/graph_saver.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index f8aa7cdb..1866fdb7 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import ( ExtractedEntityNode, EntityEntityEdge, ) - +import logging +logger = logging.getLogger(__name__) async def save_entities_and_relationships( entity_nodes: List[ExtractedEntityNode], entity_entity_edges: List[EntityEntityEdge], @@ -193,7 +194,7 @@ async def save_dialog_and_statements_to_neo4j( result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data) chunk_uuids = [record["uuid"] async for record in result] results['chunks'] = chunk_uuids - print(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") + logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") # 3. Save all statement nodes in batch if statement_nodes: @@ -202,7 +203,7 @@ async def save_dialog_and_statements_to_neo4j( result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data) statement_uuids = [record["uuid"] async for record in result] results['statements'] = statement_uuids - print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") + logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") # 4. Save entities if entity_nodes: @@ -211,7 +212,7 @@ async def save_dialog_and_statements_to_neo4j( result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data) entity_uuids = [record["uuid"] async for record in result] results['entities'] = entity_uuids - print(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j") + logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j") # 5. Create entity relationships if entity_edges: @@ -235,7 +236,7 @@ async def save_dialog_and_statements_to_neo4j( result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data) rel_uuids = [record["uuid"] async for record in result] results['entity_relationships'] = rel_uuids - print(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j") + logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j") # 6. Save statement-chunk edges if statement_chunk_edges: @@ -254,7 +255,7 @@ async def save_dialog_and_statements_to_neo4j( result = await tx.run(STATEMENT_CHUNK_EDGE_SAVE, edges=sc_edge_data) sc_uuids = [record["uuid"] async for record in result] results['statement_chunk_edges'] = sc_uuids - print(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j") + logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j") # 7. Save statement-entity edges if statement_entity_edges: @@ -273,7 +274,7 @@ async def save_dialog_and_statements_to_neo4j( result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, edges=se_edge_data) se_uuids = [record["uuid"] async for record in result] results['statement_entity_edges'] = se_uuids - print(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") + logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") return results From 3364374dc6ccb35443bbb6ccd50c85297e0f5714 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Thu, 5 Feb 2026 10:50:10 +0800 Subject: [PATCH 29/45] Write Missing None (#321) * Write Missing None * Write Missing None * Write Missing None * Apply suggestion from @sourcery-ai[bot] Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Write Missing None --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- api/app/repositories/neo4j/graph_saver.py | 33 ++++++++++++++--------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 1866fdb7..5099fd01 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -42,8 +42,8 @@ async def save_entities_and_relationships( 'statement': edge.statement, 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, - 'created_at': edge.created_at.isoformat(), - 'expired_at': edge.expired_at.isoformat(), + 'created_at': edge.created_at.isoformat() if edge.created_at else None, + 'expired_at': edge.expired_at.isoformat() if edge.expired_at else None, 'run_id': edge.run_id, 'end_user_id': edge.end_user_id, } @@ -228,8 +228,8 @@ async def save_dialog_and_statements_to_neo4j( 'statement': edge.statement, 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, - 'created_at': edge.created_at.isoformat(), - 'expired_at': edge.expired_at.isoformat(), + 'created_at': edge.created_at.isoformat() if edge.created_at else None, + 'expired_at': edge.expired_at.isoformat() if edge.expired_at else None, 'run_id': edge.run_id, 'end_user_id': edge.end_user_id, }) @@ -240,19 +240,19 @@ async def save_dialog_and_statements_to_neo4j( # 6. Save statement-chunk edges if statement_chunk_edges: - from app.repositories.neo4j.cypher_queries import STATEMENT_CHUNK_EDGE_SAVE + from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE sc_edge_data = [] for edge in statement_chunk_edges: sc_edge_data.append({ "id": edge.id, "source": edge.source, "target": edge.target, - "created_at": edge.created_at.isoformat(), - "expired_at": edge.expired_at.isoformat(), + "created_at": edge.created_at.isoformat() if edge.created_at else None, + "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, "run_id": edge.run_id, "end_user_id": edge.end_user_id, }) - result = await tx.run(STATEMENT_CHUNK_EDGE_SAVE, edges=sc_edge_data) + result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data) sc_uuids = [record["uuid"] async for record in result] results['statement_chunk_edges'] = sc_uuids logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j") @@ -263,15 +263,15 @@ async def save_dialog_and_statements_to_neo4j( se_edge_data = [] for edge in statement_entity_edges: se_edge_data.append({ - "id": edge.id, "source": edge.source, "target": edge.target, - "created_at": edge.created_at.isoformat(), - "expired_at": edge.expired_at.isoformat(), + "created_at": edge.created_at.isoformat() if edge.created_at else None, + "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, "run_id": edge.run_id, "end_user_id": edge.end_user_id, + "connect_strength": getattr(edge, "connect_strength", "strong"), }) - result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, edges=se_edge_data) + result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data) se_uuids = [record["uuid"] async for record in result] results['statement_entity_edges'] = se_uuids logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") @@ -281,10 +281,17 @@ async def save_dialog_and_statements_to_neo4j( try: # 使用显式写事务执行所有操作,避免死锁 results = await connector.execute_write_transaction(_save_all_in_transaction) - print("Successfully saved all data to Neo4j in a single transaction") + summary = { + key: len(value) + for key, value in results.items() + if isinstance(value, (list, tuple, set)) + } + logger.info("Transaction completed. Summary: %s", summary) + logger.debug("Full transaction results: %r", results) return True except Exception as e: + logger.error(f"Neo4j integration error: {e}", exc_info=True) print(f"Neo4j integration error: {e}") print("Continuing without database storage...") return False From 46ed7e38bf4e623f5d416c6508f280dd15484072 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Thu, 5 Feb 2026 12:11:45 +0800 Subject: [PATCH 30/45] Fix/release memory bug (#324) * Write Missing None * Write Missing None * Write Missing None * Apply suggestion from @sourcery-ai[bot] Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Write Missing None * redis update * redis update * redis update * redis update --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- api/app/core/memory/agent/utils/redis_tool.py | 97 +++++++++++++------ 1 file changed, 67 insertions(+), 30 deletions(-) diff --git a/api/app/core/memory/agent/utils/redis_tool.py b/api/app/core/memory/agent/utils/redis_tool.py index b61319e5..c5729628 100644 --- a/api/app/core/memory/agent/utils/redis_tool.py +++ b/api/app/core/memory/agent/utils/redis_tool.py @@ -294,6 +294,7 @@ class RedisCountStore: """ session_id = str(uuid.uuid4()) key = generate_session_key(session_id, key_type="count") + index_key = f'session:count:index:{end_user_id}' # 索引键 pipe = self.r.pipeline() pipe.hset(key, mapping={ @@ -304,6 +305,10 @@ class RedisCountStore: "starttime": get_current_timestamp() }) pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 + + # 创建索引:end_user_id -> session_id 映射 + pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60) + result = pipe.execute() print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") @@ -320,31 +325,47 @@ class RedisCountStore: list 或 False: 如果找到返回 [count, messages],否则返回 False """ try: - search_pattern = 'session:count:*' + # 使用索引键快速查找 + index_key = f'session:count:index:{end_user_id}' - for key in self.r.keys(search_pattern): - data = self.r.hgetall(key) - - if not data: - continue - - if data.get('end_user_id') == end_user_id: - count = data.get('count') - messages_str = data.get('messages') - - if count is not None: - messages = deserialize_messages(messages_str) - return [int(count), messages] + # 检查索引键类型,避免 WRONGTYPE 错误 + try: + key_type = self.r.type(index_key) + if key_type != 'string' and key_type != 'none': + self.r.delete(index_key) + return False + except Exception as type_error: + print(f"[get_sessions_count] 检查键类型失败: {type_error}") + + session_id = self.r.get(index_key) + + if not session_id: + return False + + # 直接获取数据 + key = generate_session_key(session_id, key_type="count") + data = self.r.hgetall(key) + + if not data: + # 索引存在但数据不存在,清理索引 + self.r.delete(index_key) + return False + + count = data.get('count') + messages_str = data.get('messages') + + if count is not None: + messages = deserialize_messages(messages_str) + return [int(count), messages] return False except Exception as e: print(f"[get_sessions_count] 查询失败: {e}") return False - def update_sessions_count(self, end_user_id: str, new_count: int, messages: Any) -> bool: """ - 通过 end_user_id 修改访问次数统计 + 通过 end_user_id 修改访问次数统计(优化版:使用索引) Args: end_user_id: 终端用户ID @@ -355,23 +376,39 @@ class RedisCountStore: bool: 更新成功返回 True,未找到记录返回 False """ try: + # 使用索引键快速查找 + index_key = f'session:count:index:{end_user_id}' + + # 检查索引键类型,避免 WRONGTYPE 错误 + try: + key_type = self.r.type(index_key) + if key_type != 'string' and key_type != 'none': + # 索引键类型错误,删除并返回 False + print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") + self.r.delete(index_key) + print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + return False + except Exception as type_error: + print(f"[update_sessions_count] 检查键类型失败: {type_error}") + + session_id = self.r.get(index_key) + + if not session_id: + print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + return False + + # 直接更新数据 + key = generate_session_key(session_id, key_type="count") messages_str = serialize_messages(messages) - search_pattern = 'session:count:*' - for key in self.r.keys(search_pattern): - data = self.r.hgetall(key) - - if not data: - continue - - if data.get('end_user_id') == end_user_id: - self.r.hset(key, 'count', int(new_count)) - self.r.hset(key, 'messages', messages_str) - print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") - return True + pipe = self.r.pipeline() + pipe.hset(key, 'count', int(new_count)) + pipe.hset(key, 'messages', messages_str) + result = pipe.execute() + + print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") + return True - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") - return False except Exception as e: print(f"[update_sessions_count] 更新失败: {e}") return False From 07e698265e5a8b1c467a48990fec559483775092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Thu, 5 Feb 2026 13:50:04 +0800 Subject: [PATCH 31/45] Fix/writer memory bug (#326) * [fix]Fix the bug * [fix]Fix the bug * [fix]Correct the direction indication. --- api/app/repositories/neo4j/add_edges.py | 6 +-- api/app/repositories/neo4j/add_nodes.py | 4 +- api/app/utils/config_utils.py | 58 ++++++++++++++++++------- 3 files changed, 48 insertions(+), 20 deletions(-) diff --git a/api/app/repositories/neo4j/add_edges.py b/api/app/repositories/neo4j/add_edges.py index 162bf411..2b32551c 100644 --- a/api/app/repositories/neo4j/add_edges.py +++ b/api/app/repositories/neo4j/add_edges.py @@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], try: edges: List[dict] = [] for s in summaries: - for chunk_id in getattr(s, "chunk_ids", []) or []: + chunk_ids = getattr(s, "chunk_ids", []) or [] + for chunk_id in chunk_ids: edges.append({ "summary_id": s.id, "chunk_id": chunk_id, @@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], if not edges: return [] - result = await connector.execute_query( MEMORY_SUMMARY_STATEMENT_EDGE_SAVE, edges=edges ) created = [record.get("uuid") for record in result] if result else [] return created - except Exception: + except Exception as e: return None diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index fcf700b5..42c178b3 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector summaries=flattened ) created_ids = [record.get("uuid") for record in result] + print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids - except Exception: + except Exception as e: + print(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None diff --git a/api/app/utils/config_utils.py b/api/app/utils/config_utils.py index cc67afd2..55cfe8a3 100644 --- a/api/app/utils/config_utils.py +++ b/api/app/utils/config_utils.py @@ -5,42 +5,68 @@ Shared utilities for configuration handling to avoid circular imports. """ from uuid import UUID from sqlalchemy.orm import Session +import uuid as uuid_module -def resolve_config_id(config_id: UUID | int|str, db: Session) -> UUID: +def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID: """ - 解析 config_id,如果是整数则通过 config_id_old 查找对应的 UUID + 解析 config_id,支持 UUID、UUID字符串、整数等多种格式 Args: - config_id: 配置ID(UUID 或整数) + config_id: 配置ID(UUID、UUID字符串 或 整数) db: 数据库会话 Returns: UUID: 解析后的配置ID Raises: - ValueError: 当找不到对应的配置时 + ValueError: 当找不到对应的配置时或格式无效时 """ - from app.models.memory_config_model import MemoryConfig - if isinstance(config_id, UUID): + + # 1. 如果已经是 UUID 类型,直接返回 + if isinstance(config_id, UUID): return config_id - if isinstance(config_id, str) and len(config_id)<=6: - memory_config = db.query(MemoryConfig).filter( - MemoryConfig.config_id_old == int(config_id) - ).first() - print(memory_config) - if not memory_config: - raise ValueError(f"STR 未找到 config_id_old={config_id} 对应的配置") - return memory_config.config_id + + # 2. 如果是字符串类型 + if isinstance(config_id, str): + config_id_stripped = config_id.strip() + + # 2.1 尝试解析为 UUID(标准 UUID 字符串长度为 36) + try: + return uuid_module.UUID(config_id_stripped) + except ValueError: + pass + + # 2.2 尝试解析为整数(用于查询 config_id_old) + try: + old_id = int(config_id_stripped) + if old_id > 0: + memory_config = db.query(MemoryConfig).filter( + MemoryConfig.config_id_old == old_id + ).first() + if not memory_config: + raise ValueError(f"未找到 config_id_old={old_id} 对应的配置") + return memory_config.config_id + except ValueError: + pass + + # 2.3 无法解析的字符串格式 + raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)") + + # 3. 如果是整数类型,通过 config_id_old 查找 if isinstance(config_id, int): + if config_id <= 0: + raise ValueError(f"config_id 必须是正整数: {config_id}") + memory_config = db.query(MemoryConfig).filter( MemoryConfig.config_id_old == config_id ).first() if not memory_config: - raise ValueError(f"INT 未找到 config_id_old={config_id} 对应的配置") + raise ValueError(f"未找到 config_id_old={config_id} 对应的配置") return memory_config.config_id - return config_id + # 4. 不支持的类型 + raise ValueError(f"不支持的 config_id 类型: {type(config_id).__name__}") From 169e01276d17ba2efa0c04e05af9b3e9076cb6ed Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 5 Feb 2026 13:57:25 +0800 Subject: [PATCH 32/45] fix(web): markdown table ui update --- web/src/components/Markdown/index.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/Markdown/index.tsx b/web/src/components/Markdown/index.tsx index 6737f15a..1a2c765d 100644 --- a/web/src/components/Markdown/index.tsx +++ b/web/src/components/Markdown/index.tsx @@ -51,7 +51,7 @@ const components = { audio: ({ src, ...props }: any) => , a: ({ href, children, ...props }: any) => {children}, button: ({ children }: any) => {[children]}, - table: ({ children, ...props }: any) => {children}
, + table: ({ children, ...props }: any) =>
{children}
, tr: ({ children, ...props }: any) => {children}, th: ({ children, ...props }: any) => {children}, td: ({ children, ...props }: any) => {children}, From aca7d250011554b6e2dab9a2a0a7fbc2ccf80555 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Thu, 5 Feb 2026 15:22:15 +0800 Subject: [PATCH 33/45] Fix/release memory bug (#332) * Write Missing None * Write Missing None * Write Missing None * Apply suggestion from @sourcery-ai[bot] Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Write Missing None * redis update * redis update * redis update * redis update * writer_dup_bug/fix --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- .../core/memory/agent/langgraph_graph/routing/write_router.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 6266d6d2..895f61ac 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -95,8 +95,7 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): with get_db_context() as db_session: repo = LongTermMemoryRepository(db_session) - await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + from app.core.memory.agent.utils.redis_tool import write_store result = write_store.get_session_by_userid(end_user_id) From 47b25d7a2674cda0945d43e3eebe397d210e61a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Thu, 5 Feb 2026 15:56:43 +0800 Subject: [PATCH 34/45] Fix/fact summary (#333) * [fix]Disable the contents related to fact_summary * [fix]Disable the contents related to fact_summary * [fix]Modify the code based on the AI review --- .../agent/langgraph_graph/tools/tool.py | 3 +- api/app/core/memory/models/graph_models.py | 3 +- .../deduplication/deduped_and_disamb.py | 65 ++++++++++--------- .../deduplication/entity_dedup_llm.py | 23 ++++--- .../deduplication/second_layer_dedup.py | 3 +- .../extraction_orchestrator.py | 3 +- api/app/core/memory/utils/alias_utils.py | 4 +- .../utils/prompt/prompts/entity_dedup.jinja2 | 6 +- .../repositories/memory_config_repository.py | 3 +- api/app/repositories/neo4j/cypher_queries.py | 12 ++-- 10 files changed, 73 insertions(+), 52 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py index c4814de1..fcbb18e3 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): 清理后的数据 """ # 需要过滤的字段列表 + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 fields_to_remove = { 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', - 'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary" + 'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary" } if isinstance(data, dict): diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 79b88fdc..1880b9ab 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -413,7 +413,8 @@ class ExtractedEntityNode(Node): description="Entity aliases - alternative names for this entity" ) name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") - fact_summary: str = Field(default="", description="Summary of the fact about this entity") + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary: str = Field(default="", description="Summary of the fact about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index a425e0ed..f2f14d9e 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode): if len(desc_b) > len(desc_a): canonical.description = desc_b # 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序 - fact_a = getattr(canonical, "fact_summary", "") or "" - fact_b = getattr(ent, "fact_summary", "") or "" - def _extract_sources(txt: str) -> List[str]: - sources: List[str] = [] - if not txt: - return sources - for line in str(txt).splitlines(): - ln = line.strip() + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_a = getattr(canonical, "fact_summary", "") or "" + # fact_b = getattr(ent, "fact_summary", "") or "" + # def _extract_sources(txt: str) -> List[str]: + # sources: List[str] = [] + # if not txt: + # return sources + # for line in str(txt).splitlines(): + # ln = line.strip() # 支持“来源:”或“来源:”前缀 - m = re.match(r"^来源[::]\s*(.+)$", ln) - if m: - content = m.group(1).strip() - if content: - sources.append(content) + # m = re.match(r"^来源[::]\s*(.+)$", ln) + # if m: + # content = m.group(1).strip() + # if content: + # sources.append(content) # 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失 - if not sources and txt.strip(): - sources.append(txt.strip()) - return sources + # if not sources and txt.strip(): + # sources.append(txt.strip()) + # return sources try: - src_a = _extract_sources(fact_a) - src_b = _extract_sources(fact_b) - seen = set() - merged_sources: List[str] = [] - for s in src_a + src_b: - if s and s not in seen: - seen.add(s) - merged_sources.append(s) - if merged_sources: - name_line = f"实体: {getattr(canonical, 'name', '')}".strip() - canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources]) - elif fact_b and not fact_a: - canonical.fact_summary = fact_b + # src_a = _extract_sources(fact_a) + # src_b = _extract_sources(fact_b) + # seen = set() + # merged_sources: List[str] = [] + # for s in src_a + src_b: + # if s and s not in seen: + # seen.add(s) + # merged_sources.append(s) + # if merged_sources: + # name_line = f"实体: {getattr(canonical, 'name', '')}".strip() + # canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources]) + # elif fact_b and not fact_a: + # canonical.fact_summary = fact_b + pass except Exception: # 兜底:若解析失败,保留较长文本 - if len(fact_b) > len(fact_a): - canonical.fact_summary = fact_b + # if len(fact_b) > len(fact_a): + # canonical.fact_summary = fact_b + pass except Exception: pass diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py index 0249ac1f..a028e916 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py @@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: # # 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整) desc_a = (getattr(a, "description", "") or "") desc_b = (getattr(b, "description", "") or "") - fact_a = (getattr(a, "fact_summary", "") or "") - fact_b = (getattr(b, "fact_summary", "") or "") - score_a = len(desc_a) + len(fact_a) - score_b = len(desc_b) + len(fact_b) + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_a = (getattr(a, "fact_summary", "") or "") + # fact_b = (getattr(b, "fact_summary", "") or "") + # score_a = len(desc_a) + len(fact_a) + # score_b = len(desc_b) + len(fact_b) + score_a = len(desc_a) + score_b = len(desc_b) if score_a != score_b: return 0 if score_a >= score_b else 1 return 0 @@ -189,7 +192,8 @@ async def _judge_pair( "entity_type": getattr(a, "entity_type", None), "description": getattr(a, "description", None), "aliases": getattr(a, "aliases", None) or [], - "fact_summary": getattr(a, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(a, "fact_summary", None), "connect_strength": getattr(a, "connect_strength", None), } entity_b = { @@ -197,7 +201,8 @@ async def _judge_pair( "entity_type": getattr(b, "entity_type", None), "description": getattr(b, "description", None), "aliases": getattr(b, "aliases", None) or [], - "fact_summary": getattr(b, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(b, "fact_summary", None), "connect_strength": getattr(b, "connect_strength", None), } # 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式) @@ -248,7 +253,8 @@ async def _judge_pair_disamb( "entity_type": getattr(a, "entity_type", None), "description": getattr(a, "description", None), "aliases": getattr(a, "aliases", None) or [], - "fact_summary": getattr(a, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(a, "fact_summary", None), "connect_strength": getattr(a, "connect_strength", None), } entity_b = { @@ -256,7 +262,8 @@ async def _judge_pair_disamb( "entity_type": getattr(b, "entity_type", None), "description": getattr(b, "description", None), "aliases": getattr(b, "aliases", None) or [], - "fact_summary": getattr(b, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(b, "fact_summary", None), "connect_strength": getattr(b, "connect_strength", None), } prompt = render_entity_dedup_prompt( diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py index dbc697d9..028a926f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py @@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode: description=row.get("description") or "", aliases=row.get("aliases") or [], name_embedding=row.get("name_embedding") or [], - fact_summary=row.get("fact_summary") or "", + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary=row.get("fact_summary") or "", connect_strength=row.get("connect_strength") or "", ) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 7b7e854b..8a99cb40 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1085,7 +1085,8 @@ class ExtractionOrchestrator: entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type description=getattr(entity, 'description', ''), # 添加必需的 description 字段 example=getattr(entity, 'example', ''), # 新增:传递示例字段 - fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), diff --git a/api/app/core/memory/utils/alias_utils.py b/api/app/core/memory/utils/alias_utils.py index df75752a..ff139128 100644 --- a/api/app/core/memory/utils/alias_utils.py +++ b/api/app/core/memory/utils/alias_utils.py @@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li key=lambda eid: ( _strength_rank(eid), len(getattr(entity_by_id.get(eid), 'description', '') or ''), - len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '') + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '') + 0 # 临时占位 ), reverse=True ) diff --git a/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 b/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 index be53c9d4..7fb465a2 100644 --- a/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 @@ -9,7 +9,8 @@ - 类型: "{{ entity_a.entity_type | default('') }}" - 描述: "{{ entity_a.description | default('') }}" - 别名: {{ entity_a.aliases | default([]) }} -- 摘要: "{{ entity_a.fact_summary | default('') }}" +{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #} +{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #} - 连接强弱: "{{ entity_a.connect_strength | default('') }}" 实体B: @@ -17,7 +18,8 @@ - 类型: "{{ entity_b.entity_type | default('') }}" - 描述: "{{ entity_b.description | default('') }}" - 别名: {{ entity_b.aliases | default([]) }} -- 摘要: "{{ entity_b.fact_summary | default('') }}" +{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #} +{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #} - 连接强弱: "{{ entity_b.connect_strength | default('') }}" 上下文: diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index acb68ba0..68e7cb04 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -86,7 +86,8 @@ class MemoryConfigRepository: n.description AS description, n.entity_type AS entity_type, n.name AS name, - COALESCE(n.fact_summary, '') AS fact_summary, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // COALESCE(n.fact_summary, '') AS fact_summary, n.end_user_id AS end_user_id, n.apply_id AS apply_id, n.user_id AS user_id, diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index cf1732fd..aabd0050 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity e.name_embedding = CASE WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding ELSE e.name_embedding END, - e.fact_summary = CASE - WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> '' - AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary)) - THEN entity.fact_summary ELSE e.fact_summary END, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // e.fact_summary = CASE + // WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> '' + // AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary)) + // THEN entity.fact_summary ELSE e.fact_summary END, e.connect_strength = CASE WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength ELSE CASE @@ -321,7 +322,8 @@ RETURN e.id AS id, e.description AS description, e.aliases AS aliases, e.name_embedding AS name_embedding, - COALESCE(e.fact_summary, '') AS fact_summary, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // COALESCE(e.fact_summary, '') AS fact_summary, e.connect_strength AS connect_strength, collect(DISTINCT s.id) AS statement_ids, collect(DISTINCT c.id) AS chunk_ids, From 4e7ab3d7e3569a139c59b076434b7db0530f0bda Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Thu, 5 Feb 2026 17:27:28 +0800 Subject: [PATCH 35/45] Fix/release memory bug (#335) * Write Missing None * Write Missing None * Write Missing None * Apply suggestion from @sourcery-ai[bot] Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Write Missing None * redis update * redis update * redis update * redis update * writer_dup_bug/fix * writer_graph_bug/fix * writer_graph_bug/fix --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- api/app/repositories/neo4j/cypher_queries.py | 55 ++++++++++++++++++++ api/app/services/user_memory_service.py | 18 ++----- 2 files changed, 59 insertions(+), 14 deletions(-) diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index aabd0050..651c513f 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1004,3 +1004,58 @@ RETURN DISTINCT x.statement as statement,x.created_at as created_at """ +Graph_Node_query = """ + MATCH (n:MemorySummary) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 0 AS priority + LIMIT $limit + + UNION ALL + + MATCH (n:Dialogue) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority + LIMIT 1 + + UNION ALL + + MATCH (n:Statement) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority + LIMIT $limit + + UNION ALL + + MATCH (n:ExtractedEntity) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 2 AS priority + LIMIT $limit + + UNION ALL + + MATCH (n:Chunk) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 3 AS priority + LIMIT $limit + + """ \ No newline at end of file diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 3a90a821..d5f03e85 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.conversation_repository import ConversationRepository from app.repositories.end_user_repository import EndUserRepository +from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping from app.services.implicit_memory_service import ImplicitMemoryService @@ -1508,7 +1509,6 @@ async def analytics_graph_data( user_uuid = uuid.UUID(end_user_id) repo = EndUserRepository(db) end_user = repo.get_by_id(user_uuid) - if not end_user: logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") return { @@ -1562,21 +1562,11 @@ async def analytics_graph_data( } else: # 查询所有节点 - node_query = """ - MATCH (n) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) as id, - labels(n)[0] as label, - properties(n) as properties - LIMIT $limit - """ + node_query=Graph_Node_query node_params = { "end_user_id": end_user_id, "limit": limit } - - # 执行节点查询 node_results = await _neo4j_connector.execute_query(node_query, **node_params) @@ -1587,9 +1577,9 @@ async def analytics_graph_data( for record in node_results: node_id = record["id"] - node_label = record["label"] + node_labels = record.get("labels", []) + node_label = node_labels[0] if node_labels else "Unknown" node_props = record["properties"] - # 根据节点类型提取需要的属性字段 filtered_props = await _extract_node_properties(node_label, node_props,node_id) From fe3c31c08cbb2130b6d869500b7fc919d9092359 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Fri, 6 Feb 2026 11:11:40 +0800 Subject: [PATCH 36/45] Revert "feat(web): move prompt menu" This reverts commit 9e6e8f50f8136fb8c963af34d9446dc49a237cad. --- web/src/routes/index.tsx | 1 + web/src/routes/routes.json | 2 +- web/src/store/menu.json | 30 +++++++++++++++--------------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/web/src/routes/index.tsx b/web/src/routes/index.tsx index 74cf89ec..21eaeab8 100644 --- a/web/src/routes/index.tsx +++ b/web/src/routes/index.tsx @@ -3,6 +3,7 @@ import { createHashRouter, createRoutesFromElements, Route } from 'react-router- // 导入路由配置JSON import routesConfig from './routes.json'; +import Ontology from '@/views/Ontology'; // 递归函数,用于生成路由元素 diff --git a/web/src/routes/routes.json b/web/src/routes/routes.json index aa5a8178..b02ebddf 100644 --- a/web/src/routes/routes.json +++ b/web/src/routes/routes.json @@ -7,7 +7,6 @@ { "path": "/model", "element": "ModelManagement" }, { "path": "/space", "element": "SpaceManagement" }, { "path": "/tool", "element": "ToolManagement" }, - { "path": "/prompt", "element": "Prompt" }, { "path": "/pricing", "element": "Pricing" }, { "path": "/order-pay", "element": "OrderPayment" }, { "path": "/orders", "element": "OrderHistory" }, @@ -36,6 +35,7 @@ { "path": "/reflection-engine/:id", "element": "SelfReflectionEngine" }, { "path": "/space-config", "element": "SpaceConfig" }, { "path": "/ontology", "element": "Ontology" }, + { "path": "/prompt", "element": "Prompt" }, { "path": "/no-permission", "element": "NoPermission" }, { "path": "/*", "element": "NotFound" } ] diff --git a/web/src/store/menu.json b/web/src/store/menu.json index 45da151e..d264e061 100644 --- a/web/src/store/menu.json +++ b/web/src/store/menu.json @@ -52,21 +52,6 @@ "sort": 0, "subs": [] }, - { - "id": 20, - "parent": 0, - "code": "prompt", - "label": "提示词", - "i18nKey": "menu.prompt", - "path": "/prompt", - "enable": true, - "display": true, - "level": 1, - "sort": 0, - "icon": null, - "iconActive": null, - "subs": null - }, { "id": 6, "parent": 0, @@ -392,6 +377,21 @@ "iconActive": null, "subs": null }, + { + "id": 20, + "parent": 0, + "code": "prompt", + "label": "提示词", + "i18nKey": "menu.prompt", + "path": "/prompt", + "enable": true, + "display": true, + "level": 1, + "sort": 0, + "icon": null, + "iconActive": null, + "subs": null + }, { "id": 19, "parent": 0, From 623aaf8a0e6dd28450769a967963e05f2eaaf14e Mon Sep 17 00:00:00 2001 From: zhaoying Date: Fri, 6 Feb 2026 11:28:19 +0800 Subject: [PATCH 37/45] feat(web): use memory_config_id replace memory_content --- web/src/views/ApplicationConfig/Agent.tsx | 14 +++++++------- .../ApplicationConfig/components/Skill/index.tsx | 4 ++-- web/src/views/ApplicationConfig/types.ts | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 0bfd4ba7..6feb1548 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:29:21 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-04 20:16:45 + * @Last Modified time: 2026-02-06 11:20:14 */ import { type FC, type ReactNode, useEffect, useRef, useState, forwardRef, useImperativeHandle } from 'react'; import clsx from 'clsx' @@ -38,8 +38,8 @@ import CustomSelect from '@/components/CustomSelect' import aiPrompt from '@/assets/images/application/aiPrompt.png' import AiPromptModal from './components/AiPromptModal' import ToolList from './components/ToolList/ToolList' -import ChatVariableConfigModal from './components/ChatVariableConfigModal'; import SkillList from './components/Skill' +import ChatVariableConfigModal from './components/ChatVariableConfigModal'; import type { Skill } from '@/views/Skills/types' /** @@ -169,7 +169,7 @@ const Agent = forwardRef((_props, ref) => { const { skills } = response let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : [] let allTools = Array.isArray(response.tools) ? response.tools : [] - const memoryContent = response.memory?.memory_content + const memoryContent = response.memory?.memory_config_id const parsedMemoryContent = memoryContent === null || memoryContent === '' ? undefined : !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent @@ -178,7 +178,7 @@ const Agent = forwardRef((_props, ref) => { tools: allTools, memory: { ...response.memory, - memory_content: parsedMemoryContent + memory_config_id: parsedMemoryContent }, skills: { ...skills, @@ -262,7 +262,7 @@ const Agent = forwardRef((_props, ref) => { if (!isSave || !data) return Promise.resolve() const { memory, knowledge_retrieval, tools, skills, ...rest } = values const { knowledge_bases = [], ...knowledgeRest } = knowledge_retrieval || {} - const { memory_content } = memory || {} + const { memory_config_id } = memory || {} // Get other necessary properties of memory from original data const originalMemory = data.memory || ({} as MemoryConfig) @@ -272,7 +272,7 @@ const Agent = forwardRef((_props, ref) => { memory: { ...originalMemory, ...memory, - memory_content: memory_content ? String(memory_content) : '', + memory_config_id: memory_config_id ? String(memory_config_id) : '', }, knowledge_retrieval: knowledge_bases.length > 0 ? { ...data.knowledge_retrieval, @@ -444,7 +444,7 @@ const Agent = forwardRef((_props, ref) => { diff --git a/web/src/views/ApplicationConfig/components/Skill/index.tsx b/web/src/views/ApplicationConfig/components/Skill/index.tsx index 1a8dcc6d..d42edd3d 100644 --- a/web/src/views/ApplicationConfig/components/Skill/index.tsx +++ b/web/src/views/ApplicationConfig/components/Skill/index.tsx @@ -39,7 +39,7 @@ const processObj = [ * @param value - Current skill configuration values * @param onChange - Callback function when configuration changes */ -const Skill: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) => void}> = () => { +const SkillList: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) => void}> = () => { const { t } = useTranslation() const form = Form.useFormInstance() const skillConfig = Form.useWatch(['skills'], form) @@ -148,4 +148,4 @@ const Skill: FC<{value?: SkillConfigForm; onChange?: (config: SkillConfigForm) = ) } -export default Skill \ No newline at end of file +export default SkillList \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/types.ts b/web/src/views/ApplicationConfig/types.ts index fc799b91..2d09f739 100644 --- a/web/src/views/ApplicationConfig/types.ts +++ b/web/src/views/ApplicationConfig/types.ts @@ -43,7 +43,7 @@ export interface MemoryConfig { /** Whether memory is enabled */ enabled: boolean; /** Memory content */ - memory_content?: string; + memory_config_id?: string; /** Maximum history length */ max_history?: number | string; } From 447d8790add4f6a3726125fccd28fdf7e5c3334f Mon Sep 17 00:00:00 2001 From: zhaoying Date: Fri, 6 Feb 2026 12:02:21 +0800 Subject: [PATCH 38/45] fix(web): ui update --- .../views/ModelManagement/components/MultiKeyConfigModal.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx b/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx index 2638f10c..5e362025 100644 --- a/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx +++ b/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx @@ -82,7 +82,7 @@ const MultiKeyConfigModal = forwardRef {model.api_keys.map((key) => (
-
+
{key.api_key}
{key.api_base}
From 677a603835219209f31f2d477ed4d28e61c77875 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Fri, 6 Feb 2026 12:15:49 +0800 Subject: [PATCH 39/45] fix(web): update text --- .../views/ApplicationConfig/components/Knowledge/Knowledge.tsx | 2 +- .../Workflow/components/Properties/Knowledge/Knowledge.tsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx index 1e59f26d..82fb4e59 100644 --- a/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx @@ -117,7 +117,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi title={t('application.knowledgeBaseAssociation')} extra={ - + } diff --git a/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx b/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx index da9603c8..3cd7efcd 100644 --- a/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx +++ b/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx @@ -126,7 +126,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
- {t('workflow.config.knowledge-retrieval.recallConfig')} + {t('application.globalConfig')}
From db46c186aac6c33f1b1dc198776d89dd3c6c15ec Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Fri, 6 Feb 2026 12:18:40 +0800 Subject: [PATCH 40/45] [ADD]Three party synchronization 1. Three party web website data access - Web site synchronization Building a knowledge base by crawling web page data in batches through web crawlers Web site synchronization utilizes crawler technology, which can automatically capture all websites under the same domain name through a single entry website. Currently, it supports up to 200 subpages. For compliance and security reasons, only static site crawling is supported, mainly used for quickly building knowledge bases on various document sites. 2. Feishu Knowledge Base By configuring Feishu document permissions, a knowledge base can be built using Feishu documents, and the documents will not undergo secondary storage 3. Language Bird Knowledge Base You can configure the permissions of the language bird document to build a knowledge base using the language bird document, and the document will not undergo secondary storage --- api/app/celery_app.py | 1 + api/app/controllers/knowledge_controller.py | 101 +++- api/app/core/rag/crawler/__init__.py | 0 api/app/core/rag/crawler/__main__.py | 89 +++ api/app/core/rag/crawler/content_extractor.py | 233 ++++++++ api/app/core/rag/crawler/http_fetcher.py | 302 ++++++++++ api/app/core/rag/crawler/models.py | 52 ++ api/app/core/rag/crawler/rate_limiter.py | 57 ++ api/app/core/rag/crawler/robots_parser.py | 118 ++++ api/app/core/rag/crawler/url_normalizer.py | 171 ++++++ api/app/core/rag/crawler/web_crawler.py | 215 +++++++ api/app/core/rag/integrations/__init__.py | 1 + .../core/rag/integrations/feishu/__init__.py | 1 + .../core/rag/integrations/feishu/__main__.py | 84 +++ .../core/rag/integrations/feishu/client.py | 452 +++++++++++++++ .../rag/integrations/feishu/exceptions.py | 46 ++ .../core/rag/integrations/feishu/models.py | 17 + api/app/core/rag/integrations/feishu/retry.py | 137 +++++ .../core/rag/integrations/yuque/__init__.py | 1 + .../core/rag/integrations/yuque/__main__.py | 77 +++ api/app/core/rag/integrations/yuque/client.py | 544 ++++++++++++++++++ .../core/rag/integrations/yuque/exceptions.py | 46 ++ api/app/core/rag/integrations/yuque/models.py | 42 ++ api/app/core/rag/integrations/yuque/retry.py | 134 +++++ api/app/models/file_model.py | 1 + api/app/models/knowledge_model.py | 11 + api/app/schemas/file_schema.py | 3 + api/app/tasks.py | 483 ++++++++++++++++ api/pyproject.toml | 2 + api/requirements.txt | 2 + 30 files changed, 3422 insertions(+), 1 deletion(-) create mode 100644 api/app/core/rag/crawler/__init__.py create mode 100644 api/app/core/rag/crawler/__main__.py create mode 100644 api/app/core/rag/crawler/content_extractor.py create mode 100644 api/app/core/rag/crawler/http_fetcher.py create mode 100644 api/app/core/rag/crawler/models.py create mode 100644 api/app/core/rag/crawler/rate_limiter.py create mode 100644 api/app/core/rag/crawler/robots_parser.py create mode 100644 api/app/core/rag/crawler/url_normalizer.py create mode 100644 api/app/core/rag/crawler/web_crawler.py create mode 100644 api/app/core/rag/integrations/__init__.py create mode 100644 api/app/core/rag/integrations/feishu/__init__.py create mode 100644 api/app/core/rag/integrations/feishu/__main__.py create mode 100644 api/app/core/rag/integrations/feishu/client.py create mode 100644 api/app/core/rag/integrations/feishu/exceptions.py create mode 100644 api/app/core/rag/integrations/feishu/models.py create mode 100644 api/app/core/rag/integrations/feishu/retry.py create mode 100644 api/app/core/rag/integrations/yuque/__init__.py create mode 100644 api/app/core/rag/integrations/yuque/__main__.py create mode 100644 api/app/core/rag/integrations/yuque/client.py create mode 100644 api/app/core/rag/integrations/yuque/exceptions.py create mode 100644 api/app/core/rag/integrations/yuque/models.py create mode 100644 api/app/core/rag/integrations/yuque/retry.py diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 002547f6..db78a368 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -76,6 +76,7 @@ celery_app.conf.update( # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, + 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, diff --git a/api/app/controllers/knowledge_controller.py b/api/app/controllers/knowledge_controller.py index 901208ba..01f89a3d 100644 --- a/api/app/controllers/knowledge_controller.py +++ b/api/app/controllers/knowledge_controller.py @@ -9,13 +9,16 @@ from sqlalchemy import or_ from sqlalchemy.orm import Session from app.celery_app import celery_app +from app.core.error_codes import BizCode from app.core.logging_config import get_api_logger from app.core.rag.common import settings +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.yuque.client import YuqueAPIClient from app.core.rag.llm.chat_model import Base from app.core.rag.nlp import rag_tokenizer, search from app.core.rag.prompts.generator import graph_entity_types from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory -from app.core.response_utils import success +from app.core.response_utils import success, fail from app.db import get_db from app.dependencies import get_current_user from app.models import knowledge_model @@ -484,3 +487,99 @@ async def rebuild_knowledge_graph( except Exception as e: api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}") raise + + +@router.get("/check/yuque/auth", response_model=ApiResponse) +async def check_yuque_auth( + yuque_user_id: str, + yuque_token: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + check yuque auth info + """ + api_logger.info(f"check yuque auth info, username: {current_user.username}") + + try: + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + async with api_client as client: + repos = await client.get_user_repos() + if repos: + return success(data=repos, msg="Successfully auth yuque info") + return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"auth yuque info failed: {str(e)}") + raise + + +@router.get("/check/feishu/auth", response_model=ApiResponse) +async def check_yuque_auth( + feishu_app_id: str, + feishu_app_secret: str, + feishu_folder_token: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + check feishu auth info + """ + api_logger.info(f"check feishu auth info, username: {current_user.username}") + + try: + api_client = FeishuAPIClient( + app_id=feishu_app_id, + app_secret=feishu_app_secret + ) + async with api_client as client: + files = await client.list_all_folder_files(feishu_folder_token, recursive=True) + if files: + return success(data=files, msg="Successfully auth feishu info") + return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"auth feishu info failed: {str(e)}") + raise + + +@router.post("/{knowledge_id}/sync", response_model=ApiResponse) +async def sync_knowledge( + knowledge_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + sync knowledge base information based on knowledge_id + """ + api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}") + + try: + # 1. Query knowledge base information from the database + api_logger.debug(f"Query knowledge base: {knowledge_id}") + db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user) + if not db_knowledge: + api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The knowledge base does not exist or access is denied" + ) + + # 2. sync knowledge + # from app.tasks import sync_knowledge_for_kb + # sync_knowledge_for_kb(kb_id) + task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id]) + result = { + "task_id": task.id + } + return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}") + raise diff --git a/api/app/core/rag/crawler/__init__.py b/api/app/core/rag/crawler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/core/rag/crawler/__main__.py b/api/app/core/rag/crawler/__main__.py new file mode 100644 index 00000000..51a6870f --- /dev/null +++ b/api/app/core/rag/crawler/__main__.py @@ -0,0 +1,89 @@ +"""Command-line interface for web crawler.""" + +import argparse +import logging +import sys +from app.core.rag.crawler.web_crawler import WebCrawler + + +def setup_logging(verbose: bool = False): + """Set up logging configuration.""" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout) + ] + ) + + +def main(entry_url: str, + max_pages: int = 200, + delay_seconds: float = 1.0, + timeout_seconds: int = 10, + user_agent: str = "KnowledgeBaseCrawler/1.0"): + """Main entry point for the crawler.""" + # Create crawler + crawler = WebCrawler( + entry_url=entry_url, + max_pages=max_pages, + delay_seconds=delay_seconds, + timeout_seconds=timeout_seconds, + user_agent=user_agent + ) + + # Crawl and collect documents + documents = [] + try: + for doc in crawler.crawl(): + print(f"\n{'=' * 80}") + print(f"URL: {doc.url}") + print(f"Title: {doc.title}") + print(f"Content Length: {doc.content_length} characters") + print(f"Word Count: {doc.metadata.get('word_count', 0)} words") + print(f"{'=' * 80}\n") + + documents.append({ + 'url': doc.url, + 'title': doc.title, + 'content': doc.content, + 'content_length': doc.content_length, + 'crawl_timestamp': doc.crawl_timestamp.isoformat(), + 'http_status': doc.http_status, + 'metadata': doc.metadata + }) + + except KeyboardInterrupt: + print("\n\nCrawl interrupted by user.") + + except Exception as e: + print(f"\n\nError during crawl: {e}") + sys.exit(1) + + # Get summary + summary = crawler.get_summary() + print(f"\n{'=' * 80}") + print("CRAWL SUMMARY") + print(f"{'=' * 80}") + print(f"Total Pages Processed: {summary.total_pages_processed}") + print(f"Total Errors: {summary.total_errors}") + print(f"Total Skipped: {summary.total_skipped}") + print(f"Total URLs Discovered: {summary.total_urls_discovered}") + print(f"Duration: {summary.duration_seconds:.2f} seconds") + print(f"documents: {documents}") + + if summary.error_breakdown: + print(f"\nError Breakdown:") + for error_type, count in summary.error_breakdown.items(): + print(f" {error_type}: {count}") + + +if __name__ == '__main__': + entry_url = "https://www.xxx.com" + max_pages = 20 + delay_seconds = 1.0 + timeout_seconds = 10 + user_agent = "KnowledgeBaseCrawler/1.0" + + main(entry_url, max_pages, delay_seconds, timeout_seconds, user_agent) diff --git a/api/app/core/rag/crawler/content_extractor.py b/api/app/core/rag/crawler/content_extractor.py new file mode 100644 index 00000000..69dca53c --- /dev/null +++ b/api/app/core/rag/crawler/content_extractor.py @@ -0,0 +1,233 @@ +"""Content extractor for web crawler.""" + +from bs4 import BeautifulSoup +import re +import logging + +from app.core.rag.crawler.models import ExtractedContent + +logger = logging.getLogger(__name__) + + +class ContentExtractor: + """Extract clean, readable text from HTML pages.""" + + # Tags to remove completely + REMOVE_TAGS = ['script', 'style', 'nav', 'header', 'footer', 'aside'] + + # Tags that typically contain main content + MAIN_CONTENT_TAGS = ['article', 'main'] + + # Content extraction tags + CONTENT_TAGS = ['p', 'div', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'td', 'th', 'section'] + + def is_static_content(self, html: str) -> bool: + """ + Determine if the HTML represents static content. + + Detects JavaScript-rendered content by checking for minimal body + with heavy script tag presence. + + Args: + html: Raw HTML string + + Returns: + bool: True if static, False if JavaScript-rendered + """ + try: + soup = BeautifulSoup(html, 'lxml') + + # Count script tags + script_tags = soup.find_all('script') + script_count = len(script_tags) + + # Get body content (excluding scripts and styles) + body = soup.find('body') + if not body: + return False + + # Remove scripts and styles temporarily for text check + for tag in body.find_all(['script', 'style']): + tag.decompose() + + # Get text content + text = body.get_text(strip=True) + text_length = len(text) + + # If there's very little text but many scripts, likely JS-rendered + if script_count > 5 and text_length < 200: + logger.warning("Detected JavaScript-rendered content (many scripts, little text)") + return False + + # If there's no meaningful text, likely JS-rendered + if text_length < 50: + logger.warning("Detected JavaScript-rendered content (minimal text)") + return False + + return True + + except Exception as e: + logger.error(f"Error checking if content is static: {e}") + return True # Assume static on error + + def extract(self, html: str, url: str) -> ExtractedContent: + """ + Extract clean text content from HTML. + + Args: + html: Raw HTML string + url: Source URL (for context) + + Returns: + ExtractedContent: Contains title, text, metadata + """ + try: + soup = BeautifulSoup(html, 'lxml') + + # Check if content is static + is_static = self.is_static_content(html) + + # Extract title + title = self._extract_title(soup) + + # Remove unwanted tags + for tag_name in self.REMOVE_TAGS: + for tag in soup.find_all(tag_name): + tag.decompose() + + # Extract main content + text = self._extract_main_content(soup) + + # Normalize whitespace + text = self._normalize_whitespace(text) + + # Count words + word_count = len(text.split()) + + logger.info(f"Extracted {word_count} words from {url}") + + return ExtractedContent( + title=title, + text=text, + is_static=is_static, + word_count=word_count, + metadata={'url': url} + ) + + except Exception as e: + logger.error(f"Error extracting content from {url}: {e}") + return ExtractedContent( + title=url, + text="", + is_static=False, + word_count=0, + metadata={'url': url, 'error': str(e)} + ) + + def _extract_title(self, soup: BeautifulSoup) -> str: + """ + Extract title from HTML. + + Tries tag first, then first <h1>. + + Args: + soup: BeautifulSoup object + + Returns: + str: Page title + """ + # Try <title> tag + title_tag = soup.find('title') + if title_tag and title_tag.string: + return title_tag.string.strip() + + # Try first <h1> + h1_tag = soup.find('h1') + if h1_tag: + return h1_tag.get_text(strip=True) + + # Default to empty string + return "" + + def _extract_main_content(self, soup: BeautifulSoup) -> str: + """ + Extract main content from HTML. + + Prioritizes semantic HTML5 elements like <article> and <main>. + + Args: + soup: BeautifulSoup object + + Returns: + str: Extracted text content + """ + # Try to find main content area + main_content = None + + # Priority 1: <article> or <main> tags + for tag_name in self.MAIN_CONTENT_TAGS: + main_content = soup.find(tag_name) + if main_content: + logger.debug(f"Found main content in <{tag_name}> tag") + break + + # Priority 2: div with role="main" + if not main_content: + main_content = soup.find('div', role='main') + if main_content: + logger.debug("Found main content in div[role='main']") + + # Priority 3: Common class/id patterns + if not main_content: + for pattern in ['content', 'main', 'article', 'post']: + main_content = soup.find(['div', 'section'], class_=re.compile(pattern, re.I)) + if main_content: + logger.debug(f"Found main content with class pattern '{pattern}'") + break + + main_content = soup.find(['div', 'section'], id=re.compile(pattern, re.I)) + if main_content: + logger.debug(f"Found main content with id pattern '{pattern}'") + break + + # Fallback: use body + if not main_content: + main_content = soup.find('body') + logger.debug("Using <body> as main content (no specific content area found)") + + # Extract text from content tags + if main_content: + text_parts = [] + for tag in main_content.find_all(self.CONTENT_TAGS): + text = tag.get_text(strip=True) + if text: + text_parts.append(text) + + return '\n'.join(text_parts) + + return "" + + def _normalize_whitespace(self, text: str) -> str: + """ + Normalize whitespace in text. + + - Collapse multiple spaces to single space + - Reduce excessive newlines to maximum 2 + - Strip leading/trailing whitespace + + Args: + text: Text to normalize + + Returns: + str: Normalized text + """ + # Collapse multiple spaces to single space + text = re.sub(r' +', ' ', text) + + # Reduce excessive newlines to maximum 2 + text = re.sub(r'\n{3,}', '\n\n', text) + + # Strip leading/trailing whitespace + text = text.strip() + + return text diff --git a/api/app/core/rag/crawler/http_fetcher.py b/api/app/core/rag/crawler/http_fetcher.py new file mode 100644 index 00000000..b3a08098 --- /dev/null +++ b/api/app/core/rag/crawler/http_fetcher.py @@ -0,0 +1,302 @@ +"""HTTP fetcher for web crawler.""" + +import requests +import time +import logging +import re +from typing import Optional, Dict + + +from app.core.rag.crawler.models import FetchResult + +logger = logging.getLogger(__name__) + + +class HTTPFetcher: + """Handle HTTP requests with retries, error handling, and response validation.""" + + def __init__( + self, + timeout: int = 10, + max_retries: int = 3, + user_agent: str = "KnowledgeBaseCrawler/1.0" + ): + """ + Initialize HTTP fetcher. + + Args: + timeout: Request timeout in seconds + max_retries: Maximum number of retry attempts + user_agent: User-Agent header value + """ + self.timeout = timeout + self.max_retries = max_retries + self.user_agent = user_agent + + # Create session for connection pooling + self.session = requests.Session() + self.session.headers.update({ + 'User-Agent': user_agent + }) + + def fetch(self, url: str) -> FetchResult: + """ + Fetch a URL with retry logic and error handling. + + Args: + url: URL to fetch + + Returns: + FetchResult: Contains status_code, content, headers, error info + """ + last_error = None + + for attempt in range(self.max_retries): + try: + # Calculate backoff delay for retries + if attempt > 0: + backoff_delay = 2 ** (attempt - 1) # 1s, 2s, 4s + logger.info(f"Retry attempt {attempt + 1}/{self.max_retries} for {url} after {backoff_delay}s") + time.sleep(backoff_delay) + + # Make HTTP request + response = self.session.get( + url, + timeout=self.timeout, + allow_redirects=True + ) + + # Handle different status codes + if response.status_code == 429: + # Too Many Requests - backoff and retry + logger.warning(f"429 Too Many Requests for {url}, backing off") + if attempt < self.max_retries - 1: + continue + + if response.status_code == 503: + # Service Unavailable - pause and retry + logger.warning(f"503 Service Unavailable for {url}") + if attempt < self.max_retries - 1: + time.sleep(5) # Longer pause for 503 + continue + + # Success or client error (don't retry 4xx except 429) + if 200 <= response.status_code < 300: + logger.info(f"Successfully fetched {url} (status: {response.status_code})") + + # Get correctly encoded content + content = self._get_decoded_content(response) + + return FetchResult( + url=url, + final_url=response.url, + status_code=response.status_code, + content=content, + headers=dict(response.headers), + error=None, + success=True + ) + elif response.status_code == 404: + logger.info(f"404 Not Found: {url}") + return FetchResult( + url=url, + final_url=response.url, + status_code=response.status_code, + content=None, + headers=dict(response.headers), + error="Not Found", + success=False + ) + elif 400 <= response.status_code < 500: + logger.warning(f"Client error {response.status_code} for {url}") + return FetchResult( + url=url, + final_url=response.url, + status_code=response.status_code, + content=None, + headers=dict(response.headers), + error=f"Client error: {response.status_code}", + success=False + ) + elif 500 <= response.status_code < 600: + logger.error(f"Server error {response.status_code} for {url}") + last_error = f"Server error: {response.status_code}" + if attempt < self.max_retries - 1: + continue + return FetchResult( + url=url, + final_url=url, + status_code=response.status_code, + content=None, + headers={}, + error=last_error, + success=False + ) + + except requests.exceptions.Timeout: + last_error = "Request timeout" + logger.warning(f"Timeout fetching {url} (attempt {attempt + 1}/{self.max_retries})") + if attempt >= self.max_retries - 1: + break + continue + + except requests.exceptions.SSLError as e: + last_error = f"SSL/TLS error: {str(e)}" + logger.error(f"SSL/TLS error for {url}: {e}") + return FetchResult( + url=url, + final_url=url, + status_code=0, + content=None, + headers={}, + error=last_error, + success=False + ) + + except requests.exceptions.ConnectionError as e: + last_error = f"Connection error: {str(e)}" + logger.warning(f"Connection error for {url} (attempt {attempt + 1}/{self.max_retries}): {e}") + if attempt >= self.max_retries - 1: + break + continue + + except requests.exceptions.RequestException as e: + last_error = f"Request error: {str(e)}" + logger.error(f"Request error for {url}: {e}") + if attempt >= self.max_retries - 1: + break + continue + + # All retries exhausted + logger.error(f"Failed to fetch {url} after {self.max_retries} attempts: {last_error}") + return FetchResult( + url=url, + final_url=url, + status_code=0, + content=None, + headers={}, + error=last_error or "Unknown error", + success=False + ) + + def _get_decoded_content(self, response) -> str: + """ + Get correctly decoded content from response. + + Handles encoding detection and fallback strategies: + 1. Try encoding from HTML meta tags + 2. Try response.encoding (from Content-Type header or detected) + 3. Try UTF-8 + 4. Try common encodings (GB2312, GBK for Chinese, etc.) + 5. Fall back to latin-1 with error replacement + + Args: + response: requests.Response object + + Returns: + str: Decoded content + """ + # Try to detect encoding from HTML meta tags + meta_encoding = self._detect_encoding_from_meta(response.content) + if meta_encoding: + try: + content = response.content.decode(meta_encoding) + logger.info(f"Successfully decoded with meta tag encoding: {meta_encoding}") + return content + except (UnicodeDecodeError, LookupError) as e: + logger.warning(f"Failed to decode with meta encoding {meta_encoding}: {e}") + + # Try response.encoding (from Content-Type header or detected by requests) + if response.encoding and response.encoding.lower() != 'iso-8859-1': + # Note: requests defaults to ISO-8859-1 if no charset in Content-Type, + # so we skip it here and try UTF-8 first + try: + return response.text + except (UnicodeDecodeError, LookupError) as e: + logger.warning(f"Failed to decode with detected encoding {response.encoding}: {e}") + + # Try UTF-8 first (most common) + try: + return response.content.decode('utf-8') + except UnicodeDecodeError: + logger.debug("UTF-8 decoding failed, trying other encodings") + + # Try common encodings for different languages + encodings_to_try = [ + 'gbk', # Chinese (Simplified) + 'gb2312', # Chinese (Simplified, older) + 'gb18030', # Chinese (Simplified, extended) + 'big5', # Chinese (Traditional) + 'shift_jis', # Japanese + 'euc-jp', # Japanese + 'euc-kr', # Korean + 'iso-8859-1', # Western European + 'windows-1252', # Windows Western European + 'windows-1251', # Cyrillic + ] + + for encoding in encodings_to_try: + try: + content = response.content.decode(encoding) + logger.info(f"Successfully decoded with {encoding}") + return content + except (UnicodeDecodeError, LookupError): + continue + + # Last resort: use latin-1 with error replacement + logger.warning("All encoding attempts failed, using latin-1 with error replacement") + return response.content.decode('latin-1', errors='replace') + + def _detect_encoding_from_meta(self, content: bytes) -> Optional[str]: + """ + Detect encoding from HTML meta tags. + + Looks for: + - <meta charset="..."> + - <meta http-equiv="Content-Type" content="...; charset=..."> + + Args: + content: Raw response content (bytes) + + Returns: + Optional[str]: Detected encoding or None + """ + try: + # Only check first 2KB for performance + head = content[:2048] + + # Try to decode as ASCII/Latin-1 to search for meta tags + try: + head_str = head.decode('ascii', errors='ignore') + except: + head_str = head.decode('latin-1', errors='ignore') + + # Look for <meta charset="..."> + charset_match = re.search( + r'<meta[^>]+charset=["\']?([a-zA-Z0-9_-]+)', + head_str, + re.IGNORECASE + ) + if charset_match: + encoding = charset_match.group(1).lower() + logger.debug(f"Found charset in meta tag: {encoding}") + return encoding + + # Look for <meta http-equiv="Content-Type" content="...; charset=..."> + content_type_match = re.search( + r'<meta[^>]+http-equiv=["\']?content-type["\']?[^>]+content=["\']([^"\']+)', + head_str, + re.IGNORECASE + ) + if content_type_match: + content_value = content_type_match.group(1) + charset_match = re.search(r'charset=([a-zA-Z0-9_-]+)', content_value, re.IGNORECASE) + if charset_match: + encoding = charset_match.group(1).lower() + logger.debug(f"Found charset in Content-Type meta: {encoding}") + return encoding + + except Exception as e: + logger.debug(f"Error detecting encoding from meta tags: {e}") + + return None diff --git a/api/app/core/rag/crawler/models.py b/api/app/core/rag/crawler/models.py new file mode 100644 index 00000000..5d10963c --- /dev/null +++ b/api/app/core/rag/crawler/models.py @@ -0,0 +1,52 @@ +"""Data models for web crawler.""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, Any, Optional + + +@dataclass +class CrawledDocument: + """Represents a successfully processed web page with extracted content.""" + url: str + title: str + content: str + content_length: int + crawl_timestamp: datetime + http_status: int + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class FetchResult: + """Represents the result of an HTTP fetch operation.""" + url: str + final_url: str + status_code: int + content: Optional[str] + headers: Dict[str, str] + error: Optional[str] + success: bool + + +@dataclass +class ExtractedContent: + """Represents content extracted from HTML.""" + title: str + text: str + is_static: bool + word_count: int + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class CrawlSummary: + """Represents statistics from a completed crawl.""" + total_pages_processed: int + total_errors: int + total_skipped: int + total_urls_discovered: int + start_time: datetime + end_time: datetime + duration_seconds: float + error_breakdown: Dict[str, int] = field(default_factory=dict) diff --git a/api/app/core/rag/crawler/rate_limiter.py b/api/app/core/rag/crawler/rate_limiter.py new file mode 100644 index 00000000..e00fad36 --- /dev/null +++ b/api/app/core/rag/crawler/rate_limiter.py @@ -0,0 +1,57 @@ +"""Rate limiter for web crawler.""" + +import time +import logging + +logger = logging.getLogger(__name__) + + +class RateLimiter: + """Enforce delays between requests to be polite to servers.""" + + def __init__(self, delay_seconds: float = 1.0): + """ + Initialize rate limiter. + + Args: + delay_seconds: Minimum delay between requests + """ + self.delay_seconds = delay_seconds + self.last_request_time = 0.0 + self.max_delay = 60.0 # Cap maximum delay at 60 seconds + + def wait(self): + """ + Block until enough time has passed since last request. + Respects the configured delay. + """ + current_time = time.time() + elapsed = current_time - self.last_request_time + + if elapsed < self.delay_seconds: + sleep_time = self.delay_seconds - elapsed + logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds") + time.sleep(sleep_time) + + self.last_request_time = time.time() + + def set_delay(self, delay_seconds: float): + """ + Update the delay (useful for respecting Crawl-delay from robots.txt). + + Args: + delay_seconds: New delay in seconds + """ + self.delay_seconds = min(delay_seconds, self.max_delay) + logger.info(f"Rate limiter delay updated to {self.delay_seconds} seconds") + + def backoff(self, multiplier: float = 2.0): + """ + Increase delay exponentially for backoff scenarios (429, 503 responses). + + Args: + multiplier: Factor to multiply current delay by + """ + old_delay = self.delay_seconds + self.delay_seconds = min(self.delay_seconds * multiplier, self.max_delay) + logger.warning(f"Rate limiter backing off: {old_delay:.2f}s -> {self.delay_seconds:.2f}s") diff --git a/api/app/core/rag/crawler/robots_parser.py b/api/app/core/rag/crawler/robots_parser.py new file mode 100644 index 00000000..882bc9c8 --- /dev/null +++ b/api/app/core/rag/crawler/robots_parser.py @@ -0,0 +1,118 @@ +"""Robots.txt parser for web crawler.""" + +from urllib.robotparser import RobotFileParser +from urllib.parse import urlparse, urljoin +from typing import Optional +import logging + +logger = logging.getLogger(__name__) + + +class RobotsParser: + """Parse and check robots.txt compliance for URLs.""" + + def __init__(self, user_agent: str, timeout: int = 10): + """ + Initialize robots.txt parser. + + Args: + user_agent: User agent string to check permissions for + timeout: Timeout for fetching robots.txt + """ + self.user_agent = user_agent + self.timeout = timeout + self._parsers = {} # Cache parsers by domain + + def _get_robots_url(self, url: str) -> str: + """ + Get the robots.txt URL for a given URL. + + Args: + url: URL to get robots.txt for + + Returns: + str: robots.txt URL + """ + parsed = urlparse(url) + robots_url = f"{parsed.scheme}://{parsed.netloc}/robots.txt" + return robots_url + + def _get_parser(self, url: str) -> RobotFileParser: + """ + Get or create a RobotFileParser for the domain. + + Args: + url: URL to get parser for + + Returns: + RobotFileParser: Parser for the domain + """ + robots_url = self._get_robots_url(url) + + # Return cached parser if available + if robots_url in self._parsers: + return self._parsers[robots_url] + + # Create new parser + parser = RobotFileParser() + parser.set_url(robots_url) + + try: + # Fetch and parse robots.txt + parser.read() + logger.info(f"Successfully fetched robots.txt from {robots_url}") + except Exception as e: + # If robots.txt cannot be fetched, assume all URLs are allowed + logger.warning(f"Could not fetch robots.txt from {robots_url}: {e}. Assuming all URLs allowed.") + # Create a permissive parser + parser = RobotFileParser() + parser.parse([]) # Empty robots.txt allows everything + + # Cache the parser + self._parsers[robots_url] = parser + return parser + + def can_fetch(self, url: str) -> bool: + """ + Check if the given URL can be fetched according to robots.txt. + + Args: + url: URL to check + + Returns: + bool: True if allowed, False if disallowed + """ + try: + parser = self._get_parser(url) + allowed = parser.can_fetch(self.user_agent, url) + + if not allowed: + logger.info(f"URL disallowed by robots.txt: {url}") + + return allowed + except Exception as e: + logger.error(f"Error checking robots.txt for {url}: {e}") + # On error, assume allowed + return True + + def get_crawl_delay(self, url: str) -> Optional[float]: + """ + Get the Crawl-delay directive from robots.txt if present. + + Args: + url: URL to get crawl delay for + + Returns: + Optional[float]: Delay in seconds, or None if not specified + """ + try: + parser = self._get_parser(url) + delay = parser.crawl_delay(self.user_agent) + + if delay is not None: + logger.info(f"Crawl-delay from robots.txt: {delay} seconds") + + return delay + except Exception as e: + logger.error(f"Error getting crawl delay for {url}: {e}") + return None diff --git a/api/app/core/rag/crawler/url_normalizer.py b/api/app/core/rag/crawler/url_normalizer.py new file mode 100644 index 00000000..7762a9d5 --- /dev/null +++ b/api/app/core/rag/crawler/url_normalizer.py @@ -0,0 +1,171 @@ +"""URL normalization and validation for web crawler.""" + +from typing import Optional, List +from urllib.parse import urlparse, urlunparse, parse_qs, urlencode, urljoin +from bs4 import BeautifulSoup + + +class URLNormalizer: + """Normalize and validate URLs for deduplication and domain checking.""" + + # Common tracking parameters to remove + TRACKING_PARAMS = { + 'utm_source', 'utm_medium', 'utm_campaign', 'utm_term', 'utm_content', + 'fbclid', 'gclid', 'msclkid', '_ga', 'mc_cid', 'mc_eid' + } + + def __init__(self, base_domain: str): + """ + Initialize URL normalizer with base domain. + + Args: + base_domain: The domain to use for same-domain checks + """ + parsed = urlparse(base_domain) + self.base_domain = parsed.netloc.lower() # example.com:8000 + self.base_scheme = parsed.scheme or 'https' # https + + def normalize(self, url: str) -> Optional[str]: + """ + Normalize a URL for deduplication. + + Normalization rules: + 1. Convert domain to lowercase + 2. Remove fragments (#section) + 3. Remove default ports (80 for http, 443 for https) + 4. Remove trailing slashes (except for root) + 5. Sort query parameters alphabetically + 6. Remove common tracking parameters + + Args: + url: URL to normalize + + Returns: + Optional[str]: Normalized URL, or None if invalid + """ + try: + parsed = urlparse(url) + + # Validate scheme + if parsed.scheme not in ('http', 'https'): + return None + + # Normalize domain to lowercase + netloc = parsed.netloc.lower() + + # Remove default ports + if ':' in netloc: + host, port = netloc.rsplit(':', 1) + if (parsed.scheme == 'http' and port == '80') or \ + (parsed.scheme == 'https' and port == '443'): + netloc = host + + # Normalize path + path = parsed.path + # Remove trailing slash except for root + if path != '/' and path.endswith('/'): + path = path.rstrip('/') + # Ensure path starts with / + if not path: + path = '/' + + # Process query parameters + query = '' + if parsed.query: + # Parse query parameters + params = parse_qs(parsed.query, keep_blank_values=True) + # Remove tracking parameters + filtered_params = { + k: v for k, v in params.items() + if k not in self.TRACKING_PARAMS + } + # Sort parameters alphabetically + if filtered_params: + sorted_params = sorted(filtered_params.items()) + query = urlencode(sorted_params, doseq=True) + + # Reconstruct URL without fragment + normalized = urlunparse(( + parsed.scheme, + netloc, + path, + parsed.params, + query, + '' # Remove fragment + )) + + return normalized + + except Exception: + return None + + def is_same_domain(self, url: str) -> bool: + """ + Check if URL belongs to the same domain as base_domain. + + Args: + url: URL to check + + Returns: + bool: True if same domain, False otherwise + """ + try: + parsed = urlparse(url) + domain = parsed.netloc.lower() + + # Remove port if present + if ':' in domain: + domain = domain.split(':')[0] + + # Check if domains match + return domain == self.base_domain or domain == self.base_domain.split(':')[0] + + except Exception: + return False + + def extract_links(self, html: str, base_url: str) -> List[str]: + """ + Extract and normalize all links from HTML. + + Args: + html: HTML content + base_url: Base URL for resolving relative links + + Returns: + List[str]: List of normalized absolute URLs + """ + links = [] + + try: + soup = BeautifulSoup(html, 'lxml') + + # Find all anchor tags + for anchor in soup.find_all('a', href=True): + href = anchor['href'] + + # Skip empty hrefs + if not href or href.strip() == '': + continue + + # Skip javascript: and mailto: links + if href.startswith(('javascript:', 'mailto:', 'tel:')): + continue + + normalized_url = None + # Check if href starts with http/https (absolute URL) + if href.startswith(('http://', 'https://')): + if self.is_same_domain(href): + normalized_url = self.normalize(href) + else: + # Convert relative URL to absolute + absolute_url = urljoin(base_url, href) + # Normalize the URL + normalized_url = self.normalize(absolute_url) + + if normalized_url: + links.append(normalized_url) + + except Exception: + pass + + return links diff --git a/api/app/core/rag/crawler/web_crawler.py b/api/app/core/rag/crawler/web_crawler.py new file mode 100644 index 00000000..3afa09b2 --- /dev/null +++ b/api/app/core/rag/crawler/web_crawler.py @@ -0,0 +1,215 @@ +"""Main web crawler orchestrator.""" + +from collections import deque +from datetime import datetime +from typing import Iterator, Optional, List, Set +from urllib.parse import urlparse +import logging + +from app.core.rag.crawler.url_normalizer import URLNormalizer +from app.core.rag.crawler.robots_parser import RobotsParser +from app.core.rag.crawler.rate_limiter import RateLimiter +from app.core.rag.crawler.http_fetcher import HTTPFetcher +from app.core.rag.crawler.content_extractor import ContentExtractor +from app.core.rag.crawler.models import CrawledDocument, CrawlSummary + +logger = logging.getLogger(__name__) + + +class WebCrawler: + """Main orchestrator for web crawling.""" + + def __init__( + self, + entry_url: str, + max_pages: int = 200, + delay_seconds: float = 1.0, + timeout_seconds: int = 10, + user_agent: str = "KnowledgeBaseCrawler/1.0", + include_patterns: Optional[List[str]] = None, + exclude_patterns: Optional[List[str]] = None, + content_extractor: Optional[ContentExtractor] = None + ): + """ + Initialize the web crawler. + + Args: + entry_url: Starting URL for the crawl + max_pages: Maximum number of pages to crawl (default: 200) + delay_seconds: Delay between requests in seconds (default: 1.0) + timeout_seconds: HTTP request timeout (default: 10) + user_agent: User-Agent header string + include_patterns: List of regex patterns for URLs to include + exclude_patterns: List of regex patterns for URLs to exclude + content_extractor: Custom content extractor (optional) + """ + # Validate entry URL + parsed = urlparse(entry_url) + if not parsed.scheme or not parsed.netloc: + raise ValueError(f"Invalid entry URL: {entry_url}") + + self.entry_url = entry_url + self.max_pages = max_pages + self.user_agent = user_agent + + # Extract domain from entry URL + self.domain = parsed.netloc + + # Initialize components + self.url_normalizer = URLNormalizer(entry_url) + self.robots_parser = RobotsParser(user_agent, timeout_seconds) + self.rate_limiter = RateLimiter(delay_seconds) + self.http_fetcher = HTTPFetcher(timeout_seconds, max_retries=3, user_agent=user_agent) + self.content_extractor = content_extractor or ContentExtractor() + + # State management + self.url_queue: deque = deque() + self.visited_urls: Set[str] = set() + self.pages_processed = 0 + + # Statistics + self.stats = { + 'success': 0, + 'errors': 0, + 'skipped': 0, + 'urls_discovered': 0, + 'error_breakdown': {} + } + self.start_time: Optional[datetime] = None + self.end_time: Optional[datetime] = None + + def crawl(self) -> Iterator[CrawledDocument]: + """ + Execute the crawl and yield documents as they are processed. + + Yields: + CrawledDocument: Structured document with extracted content + """ + logger.info(f"Starting crawl from {self.entry_url} (max_pages: {self.max_pages})") + self.start_time = datetime.now() + + # Add entry URL to queue + normalized_entry = self.url_normalizer.normalize(self.entry_url) + if normalized_entry: + self.url_queue.append(normalized_entry) + self.stats['urls_discovered'] += 1 + + # Check robots.txt and update rate limiter if needed + crawl_delay = self.robots_parser.get_crawl_delay(self.entry_url) + if crawl_delay: + self.rate_limiter.set_delay(crawl_delay) + + # Main crawl loop + while self.url_queue and self.pages_processed < self.max_pages: + url = self.url_queue.popleft() + + # Skip if already visited + if url in self.visited_urls: + continue + + # Mark as visited + self.visited_urls.add(url) + + # Check robots.txt permission + if not self.robots_parser.can_fetch(url): + logger.info(f"Skipping {url} (disallowed by robots.txt)") + self.stats['skipped'] += 1 + continue + + # Apply rate limiting + self.rate_limiter.wait() + + # Fetch URL + logger.info(f"Fetching {url} ({self.pages_processed + 1}/{self.max_pages})") + fetch_result = self.http_fetcher.fetch(url) + + # Handle fetch errors + if not fetch_result.success: + self._record_error(fetch_result.error or "Unknown error") + continue + + # Check Content-Type + content_type = fetch_result.headers.get('Content-Type', '').lower() + if not any(substring in content_type for substring in ['text/html', 'application/xhtml+xml']): + logger.warning(f"Skipping {url} (Content-Type: {content_type})") + self.stats['skipped'] += 1 + continue + + # Extract content + try: + extracted = self.content_extractor.extract(fetch_result.content, url) + + # Check if static content + if not extracted.is_static: + logger.warning(f"Skipping {url} (JavaScript-rendered content)") + self.stats['skipped'] += 1 + continue + + # Create document + document = CrawledDocument( + url=url, + title=extracted.title, + content=extracted.text, + content_length=len(extracted.text), + crawl_timestamp=datetime.now(), + http_status=fetch_result.status_code, + metadata={ + 'word_count': extracted.word_count, + 'final_url': fetch_result.final_url + } + ) + + # Update statistics + self.pages_processed += 1 + self.stats['success'] += 1 + + # Extract and queue links + links = self.url_normalizer.extract_links(fetch_result.content, url) + for link in links: + if link not in self.visited_urls and self.url_normalizer.is_same_domain(link): + if link not in self.url_queue: + self.url_queue.append(link) + self.stats['urls_discovered'] += 1 + + # Yield document + yield document + + except Exception as e: + logger.error(f"Error processing {url}: {e}") + self._record_error(f"Processing error: {str(e)}") + continue + + self.end_time = datetime.now() + logger.info(f"Crawl completed. Processed {self.pages_processed} pages.") + + def get_summary(self) -> CrawlSummary: + """ + Get summary statistics after crawl completion. + + Returns: + CrawlSummary: Statistics including success/error/skip counts + """ + if not self.start_time: + self.start_time = datetime.now() + if not self.end_time: + self.end_time = datetime.now() + + duration = (self.end_time - self.start_time).total_seconds() + + return CrawlSummary( + total_pages_processed=self.stats['success'], + total_errors=self.stats['errors'], + total_skipped=self.stats['skipped'], + total_urls_discovered=self.stats['urls_discovered'], + start_time=self.start_time, + end_time=self.end_time, + duration_seconds=duration, + error_breakdown=self.stats['error_breakdown'] + ) + + def _record_error(self, error: str): + """Record an error in statistics.""" + self.stats['errors'] += 1 + error_type = error.split(':')[0] if ':' in error else error + self.stats['error_breakdown'][error_type] = \ + self.stats['error_breakdown'].get(error_type, 0) + 1 diff --git a/api/app/core/rag/integrations/__init__.py b/api/app/core/rag/integrations/__init__.py new file mode 100644 index 00000000..c1c43854 --- /dev/null +++ b/api/app/core/rag/integrations/__init__.py @@ -0,0 +1 @@ +"""Integrations package for external services.""" diff --git a/api/app/core/rag/integrations/feishu/__init__.py b/api/app/core/rag/integrations/feishu/__init__.py new file mode 100644 index 00000000..d989b816 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/__init__.py @@ -0,0 +1 @@ +"""Feishu integration module for document synchronization.""" diff --git a/api/app/core/rag/integrations/feishu/__main__.py b/api/app/core/rag/integrations/feishu/__main__.py new file mode 100644 index 00000000..79d5a48e --- /dev/null +++ b/api/app/core/rag/integrations/feishu/__main__.py @@ -0,0 +1,84 @@ +"""Command-line interface for feishu integration.""" + +import asyncio +import sys +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.feishu.models import FileInfo + + +def main(feishu_app_id: str, # Feishu application ID + feishu_app_secret: str, # Feishu application secret + feishu_folder_token: str, # Feishu Folder Token + save_dir: str, # save file directory + feishu_api_base_url: str = "https://open.feishu.cn/open-apis", # Feishu API base URL + timeout: int = 30, # Request timeout in seconds + max_retries: int = 3, # Maximum number of retries + recursive: bool = True # recursive: Whether to sync subfolders recursively, + ): + """Main entry point for the feishuAPIClient.""" + # Create feishuAPIClient + api_client = FeishuAPIClient( + app_id=feishu_app_id, + app_secret=feishu_app_secret, + api_base_url=feishu_api_base_url, + timeout=timeout, + max_retries=max_retries + ) + + # Get all files from folder + async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str): + async with api_client as client: + if recursive: + files = await client.list_all_folder_files(feishu_folder_token, recursive=True) + else: + all_files = [] + page_token = None + while True: + files_page, page_token = await client.list_folder_files( + feishu_folder_token, page_token + ) + all_files.extend(files_page) + if not page_token: + break + files = all_files + return files + files = asyncio.run(async_get_files(api_client,feishu_folder_token)) + + # Filter out folders, only sync documents + # documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file", "slides"]] + documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]] + + try: + for doc in documents: + print(f"\n{'=' * 80}") + print(f"token: {doc.token}") + print(f"name: {doc.name}") + print(f"type: {doc.type}") + print(f"created_time: {doc.created_time}") + print(f"modified_time: {doc.modified_time}") + print(f"owner_id: {doc.owner_id}") + print(f"url: {doc.url}") + print(f"{'=' * 80}\n") + # download document from Feishu FileInfo + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(document=doc, save_dir=save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + print(file_path) + + except KeyboardInterrupt: + print("\n\nfeishu integration interrupted by user.") + + except Exception as e: + print(f"\n\nError during feishu integration: {e}") + sys.exit(1) + + +if __name__ == '__main__': + feishu_app_id = "" + feishu_app_secret = "" + feishu_folder_token = "" + save_dir = "/Volumes/MacintoshBD/Repository/RedBearAI/MemoryBear/api/files/" + main(feishu_app_id, feishu_app_secret, feishu_folder_token, save_dir) diff --git a/api/app/core/rag/integrations/feishu/client.py b/api/app/core/rag/integrations/feishu/client.py new file mode 100644 index 00000000..0a3c4ea8 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/client.py @@ -0,0 +1,452 @@ +"""Feishu API client for document operations.""" + +import asyncio +import os +import re +from typing import Optional, Tuple, List +from datetime import datetime, timedelta +import httpx +from cachetools import TTLCache +import urllib.parse + +from app.core.rag.integrations.feishu.exceptions import ( + FeishuAuthError, + FeishuAPIError, + FeishuNotFoundError, + FeishuPermissionError, + FeishuRateLimitError, + FeishuNetworkError, +) +from app.core.rag.integrations.feishu.models import FileInfo +from app.core.rag.integrations.feishu.retry import with_retry + + +class FeishuAPIClient: + """Feishu API client for document synchronization.""" + + def __init__( + self, + app_id: str, + app_secret: str, + api_base_url: str = "https://open.feishu.cn/open-apis", + timeout: int = 30, + max_retries: int = 3 + ): + """ + Initialize Feishu API client. + + Args: + app_id: Feishu application ID + app_secret: Feishu application secret + api_base_url: Feishu API base URL + timeout: Request timeout in seconds + max_retries: Maximum number of retries + """ + self.app_id = app_id + self.app_secret = app_secret + self.api_base_url = api_base_url + self.timeout = timeout + self.max_retries = max_retries + self._http_client: Optional[httpx.AsyncClient] = None + self._token_cache: TTLCache = TTLCache(maxsize=1, ttl=7200 - 300) # 2 hours - 5 minutes + self._token_lock = asyncio.Lock() + + async def __aenter__(self): + """Async context manager entry.""" + self._http_client = httpx.AsyncClient( + base_url=self.api_base_url, + timeout=self.timeout, + headers={"Content-Type": "application/json"} + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._http_client: + await self._http_client.aclose() + + async def get_tenant_access_token(self) -> str: + """ + Get tenant access token with caching. + + Returns: + Access token string + + Raises: + FeishuAuthError: If authentication fails + """ + # Check cache first + cached_token = self._token_cache.get("access_token") + if cached_token: + return cached_token + + # Use lock to prevent concurrent token requests + async with self._token_lock: + # Double-check cache after acquiring lock + cached_token = self._token_cache.get("access_token") + if cached_token: + return cached_token + + # Request new token + try: + if not self._http_client: + raise FeishuAuthError("HTTP client not initialized") + + response = await self._http_client.post( + "/auth/v3/tenant_access_token/internal", + json={ + "app_id": self.app_id, + "app_secret": self.app_secret + } + ) + + data = response.json() + + if data.get("code") != 0: + error_msg = data.get("msg", "Unknown error") + raise FeishuAuthError( + f"Authentication failed: {error_msg}", + error_code=str(data.get("code")), + details=data + ) + + token = data.get("tenant_access_token") + if not token: + raise FeishuAuthError("No access token in response") + + # Cache the token + self._token_cache["access_token"] = token + + return token + + except httpx.HTTPError as e: + raise FeishuAuthError(f"HTTP error during authentication: {str(e)}") + except Exception as e: + if isinstance(e, FeishuAuthError): + raise + raise FeishuAuthError(f"Unexpected error during authentication: {str(e)}") + + @with_retry + async def list_folder_files( + self, + folder_token: str, + page_token: Optional[str] = None + ) -> Tuple[List[FileInfo], Optional[str]]: + """ + Get list of files in a folder with pagination support. + + Args: + folder_token: Folder token + page_token: Page token for pagination + + Returns: + Tuple of (list of FileInfo, next page token) + + Raises: + FeishuAPIError: If API call fails + FeishuNotFoundError: If folder not found + FeishuPermissionError: If permission denied + """ + try: + token = await self.get_tenant_access_token() + + if not self._http_client: + raise FeishuAPIError("HTTP client not initialized") + + # Build request parameters + params = {"page_size": 200, "folder_token": folder_token} + if page_token: + params["page_token"] = page_token + + # Make API request + response = await self._http_client.get( + f"/drive/v1/files", + params=params, + headers={"Authorization": f"Bearer {token}"} + ) + + data = response.json() + # print(f"get files: {data}") + + # Handle errors + if data.get("code") != 0: + error_code = data.get("code") + error_msg = data.get("msg", "Unknown error") + + if error_code == 404 or error_code == 230005: + raise FeishuNotFoundError( + f"Folder not found: {error_msg}", + error_code=str(error_code), + details=data + ) + elif error_code == 403 or error_code == 230003: + raise FeishuPermissionError( + f"Permission denied: {error_msg}", + error_code=str(error_code), + details=data + ) + else: + raise FeishuAPIError( + f"API error: {error_msg}", + error_code=str(error_code), + details=data + ) + + # Parse response + files_data = data.get("data", {}).get("files", []) + next_page_token = data.get("data", {}).get("next_page_token", None) + + # Convert to FileInfo objects + files = [] + for file_data in files_data: + try: + file_info = FileInfo( + token=file_data.get("token", ""), + name=file_data.get("name", ""), + type=file_data.get("type", ""), + created_time=datetime.fromtimestamp(int(file_data.get("created_time", 0))), + modified_time=datetime.fromtimestamp(int(file_data.get("modified_time", 0))), + owner_id=file_data.get("owner_id", ""), + url=file_data.get("url", "") + ) + files.append(file_info) + except (ValueError, TypeError) as e: + # Skip invalid file entries + continue + + return files, next_page_token + + except httpx.HTTPError as e: + raise FeishuAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError)): + raise + raise FeishuAPIError(f"Unexpected error: {str(e)}") + + async def list_all_folder_files( + self, + folder_token: str, + recursive: bool = True + ) -> List[FileInfo]: + """ + Get all files in a folder, handling pagination automatically. + + Args: + folder_token: Folder token + recursive: Whether to recursively get files from subfolders + + Returns: + List of all FileInfo objects + + Raises: + FeishuAPIError: If API call fails + """ + all_files = [] + page_token = None + + # Get all files with pagination + while True: + files, page_token = await self.list_folder_files(folder_token, page_token) + all_files.extend(files) + + if not page_token: + break + + # Recursively get files from subfolders if requested + if recursive: + subfolders = [f for f in all_files if f.type == "folder"] + for subfolder in subfolders: + try: + subfolder_files = await self.list_all_folder_files( + subfolder.token, + recursive=True + ) + all_files.extend(subfolder_files) + except Exception: + # Continue with other folders if one fails + continue + + return all_files + + @with_retry + async def download_document( + self, + document: FileInfo, + save_dir: str + ) -> str: + """ + download document content. + + Args: + document: Document FileInfo + save_dir: save dir + + Returns: + file_full_path + + Raises: + FeishuAPIError: If API call fails + FeishuNotFoundError: If document not found + FeishuPermissionError: If permission denied + """ + try: + token = await self.get_tenant_access_token() + + if not self._http_client: + raise FeishuAPIError("HTTP client not initialized") + + # Different API endpoints for different document types + if document.type == "doc" or document.type == "docx" or document.type == "sheet" or document.type == "bitable": + return await self._export_file(document, token, save_dir) + elif document.type == "file" or document.type == "slides": + return await self._download_file(document, token, save_dir) + else: + raise FeishuAPIError(f"Unsupported document type: {document.type}") + + except Exception as e: + if isinstance(e, (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError)): + raise + raise FeishuAPIError(f"Unexpected error: {str(e)}") + + async def _export_file(self, document: FileInfo, access_token: str, save_dir: str) -> str: + """export file for feishu online file type.""" + try: + # 1.创建导出任务 + file_extension = "pdf" + match document.type: + case "doc": + file_extension = "doc" + case "docx": + file_extension = "docx" + case "sheet": + file_extension = "xlsx" + case "bitable": + file_extension = "xlsx" + case _: + file_extension = "pdf" + response = await self._http_client.post( + "/drive/v1/export_tasks", + json={ + "file_extension": file_extension, + "token": document.token, + "type": document.type + }, + headers={"Authorization": f"Bearer {access_token}"} + ) + data = response.json() + print(f"1.创建导出任务: {data}") + + if data.get("code") != 0: + error_code = data.get("code") + error_msg = data.get("msg", "Unknown error") + raise FeishuAPIError( + f"API error: {error_msg}", + error_code=str(error_code), + details=data + ) + + ticket = data.get("data", {}).get("ticket", None) + if not ticket: + raise FeishuAuthError("No ticket in response") + + # 2.轮序查询导出任务结果 + max_retries = 10 # 最大轮询次数 + poll_interval = 2 # 每次轮询间隔时间(秒) + file_token = None + for attempt in range(max_retries): + # 查询导出任务 + response = await self._http_client.get( + f"/drive/v1/export_tasks/{ticket}", + params={"token": document.token}, + headers={"Authorization": f"Bearer {access_token}"} + ) + data = response.json() + print(f"2. 尝试查询导出任务结果 (第{attempt + 1}次): {data}") + + if data.get("code") != 0: + error_code = data.get("code") + error_msg = data.get("msg", "Unknown error") + raise FeishuAPIError( + f"API error: {error_msg}", + error_code=str(error_code), + details=data, + ) + + # 检查导出任务结果 + file_token = data.get("data", {}).get("result", {}).get("file_token", None) + if file_token: + # 如果导出任务成功生成 file_token,则退出轮询 + break + + # 如果结果还没准备好,等待一段时间再进行下一次轮询 + await asyncio.sleep(poll_interval) + + if not file_token: + raise FeishuAPIError("Export task did not complete within the allowed time") + + # 3.下载导出任务 + response = await self._http_client.get( + f"/drive/v1/export_tasks/file/{file_token}/download", + headers={"Authorization": f"Bearer {access_token}"} + ) + response.raise_for_status() + print(f'3.下载导出任务: {response.headers.get("Content-Disposition")}') + + file_full_path = os.path.join(save_dir, document.name + "." + file_extension) + if os.path.exists(file_full_path): + os.remove(file_full_path) # Delete a single file + with open(file_full_path, "wb") as file: + file.write(response.content) + + return file_full_path + + except httpx.HTTPError as e: + raise FeishuAPIError(f"HTTP error: {str(e)}") + except Exception as e: + raise FeishuAPIError(f"Unexpected error during file download: {str(e)}") + + async def _download_file(self, document: FileInfo, access_token: str, save_dir: str) -> str: + """download file for file type.""" + try: + response = await self._http_client.get( + f"/drive/v1/files/{document.token}/download", + headers={"Authorization": f"Bearer {access_token}"} + ) + response.raise_for_status() + + filename_header = response.headers.get("Content-Disposition") + + # 最终的文件名(初始化为 None) + filename = None + if filename_header: + # 优先解析 filename* 格式 + match = re.search(r"filename\*=([^']*)''([^;]+)", filename_header) + if match: + # 使用 `filename*` 提取(已编码) + encoding = match.group(1) # 编码部分(如 UTF-8) + encoded_filename = match.group(2) # 文件名部分 + filename = urllib.parse.unquote(encoded_filename) # 解码 URL 编码的文件名 + + # 如果 `filename*` 不存在,回退到解析 `filename` + if not filename: + match = re.search(r'filename="([^"]+)"', filename_header) + if match: + filename = match.group(1) + # 如果文件名仍为 None,则使用默认文件名 + if not filename: + filename = f"{document.name}.pdf" + # 确保文件名合法,替换非法字符 + filename = re.sub(r'[\/:*?"<>|]', '_', filename) + + file_full_path = os.path.join(save_dir, filename) + if os.path.exists(file_full_path): + os.remove(file_full_path) # Delete a single file + with open(file_full_path, "wb") as file: + file.write(response.content) + + return file_full_path + + except httpx.HTTPError as e: + raise FeishuAPIError(f"HTTP error: {str(e)}") + except Exception as e: + raise FeishuAPIError(f"Unexpected error during file download: {str(e)}") diff --git a/api/app/core/rag/integrations/feishu/exceptions.py b/api/app/core/rag/integrations/feishu/exceptions.py new file mode 100644 index 00000000..26e42a07 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/exceptions.py @@ -0,0 +1,46 @@ +"""Exception classes for Feishu integration.""" + + +class FeishuError(Exception): + """Base exception for all Feishu-related errors.""" + + def __init__(self, message: str, error_code: str = None, details: dict = None): + super().__init__(message) + self.message = message + self.error_code = error_code + self.details = details or {} + + +class FeishuAuthError(FeishuError): + """Authentication error with Feishu API.""" + pass + + +class FeishuAPIError(FeishuError): + """General API error from Feishu.""" + pass + + +class FeishuNotFoundError(FeishuError): + """Resource not found error (404).""" + pass + + +class FeishuPermissionError(FeishuError): + """Permission denied error (403).""" + pass + + +class FeishuRateLimitError(FeishuError): + """Rate limit exceeded error (429).""" + pass + + +class FeishuNetworkError(FeishuError): + """Network-related error (timeout, connection failure).""" + pass + + +class FeishuDataError(FeishuError): + """Data parsing or validation error.""" + pass diff --git a/api/app/core/rag/integrations/feishu/models.py b/api/app/core/rag/integrations/feishu/models.py new file mode 100644 index 00000000..b194afc1 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/models.py @@ -0,0 +1,17 @@ +"""Data models for Feishu integration.""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Dict, Any, List, Optional + + +@dataclass +class FileInfo: + """File information from Feishu.""" + token: str + name: str + type: str # doc/docx/sheet/bitable/file/slides/folder + created_time: datetime + modified_time: datetime + owner_id: str + url: str diff --git a/api/app/core/rag/integrations/feishu/retry.py b/api/app/core/rag/integrations/feishu/retry.py new file mode 100644 index 00000000..c1d9aff1 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/retry.py @@ -0,0 +1,137 @@ +"""Retry strategy for Feishu API calls.""" + +import asyncio +import functools +from typing import Callable, TypeVar +import httpx + +from app.core.rag.integrations.feishu.exceptions import ( + FeishuAuthError, + FeishuPermissionError, + FeishuNotFoundError, + FeishuRateLimitError, + FeishuNetworkError, + FeishuDataError, + FeishuAPIError, +) + +T = TypeVar('T') + + +class RetryStrategy: + """Retry strategy for API calls.""" + + # Retryable error types + RETRYABLE_ERRORS = ( + FeishuNetworkError, + FeishuRateLimitError, + httpx.TimeoutException, + httpx.ConnectError, + httpx.ReadError, + ) + + # Non-retryable error types + NON_RETRYABLE_ERRORS = ( + FeishuAuthError, + FeishuPermissionError, + FeishuNotFoundError, + FeishuDataError, + ) + + # Retry configuration + MAX_RETRIES = 3 + BACKOFF_DELAYS = [1, 2, 4] # seconds + + @classmethod + def is_retryable(cls, error: Exception) -> bool: + """Check if an error is retryable.""" + # Check for specific retryable errors + if isinstance(error, cls.RETRYABLE_ERRORS): + return True + + # Check for non-retryable errors + if isinstance(error, cls.NON_RETRYABLE_ERRORS): + return False + + # Check for HTTP status codes + if isinstance(error, httpx.HTTPStatusError): + status_code = error.response.status_code + # Retry on 429 (rate limit), 503 (service unavailable), 502 (bad gateway) + if status_code in [429, 502, 503]: + return True + # Don't retry on 4xx errors (except 429) + if 400 <= status_code < 500: + return False + # Retry on 5xx errors + if 500 <= status_code < 600: + return True + + # Check for FeishuAPIError with specific codes + if isinstance(error, FeishuAPIError): + if error.error_code: + # Rate limit error codes + if error.error_code in ["99991400", "99991401"]: + return True + + return False + + @classmethod + async def execute_with_retry( + cls, + func: Callable[..., T], + *args, + **kwargs + ) -> T: + """ + Execute a function with retry logic. + + Args: + func: Async function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Function result + + Raises: + Exception: The last exception if all retries fail + """ + last_exception = None + + for attempt in range(cls.MAX_RETRIES + 1): + try: + return await func(*args, **kwargs) + except Exception as e: + last_exception = e + + # Don't retry if not retryable + if not cls.is_retryable(e): + raise + + # Don't retry if this was the last attempt + if attempt >= cls.MAX_RETRIES: + raise + + # Wait before retrying + delay = cls.BACKOFF_DELAYS[attempt] if attempt < len(cls.BACKOFF_DELAYS) else cls.BACKOFF_DELAYS[-1] + await asyncio.sleep(delay) + + # Should not reach here, but raise last exception if we do + if last_exception: + raise last_exception + + +def with_retry(func: Callable[..., T]) -> Callable[..., T]: + """ + Decorator to add retry logic to async functions. + + Usage: + @with_retry + async def my_api_call(): + ... + """ + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await RetryStrategy.execute_with_retry(func, *args, **kwargs) + + return wrapper diff --git a/api/app/core/rag/integrations/yuque/__init__.py b/api/app/core/rag/integrations/yuque/__init__.py new file mode 100644 index 00000000..dc4f2a17 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/__init__.py @@ -0,0 +1 @@ +"""Yuque integration module for document synchronization.""" diff --git a/api/app/core/rag/integrations/yuque/__main__.py b/api/app/core/rag/integrations/yuque/__main__.py new file mode 100644 index 00000000..3b87bbcd --- /dev/null +++ b/api/app/core/rag/integrations/yuque/__main__.py @@ -0,0 +1,77 @@ +"""Main entry point for Yuque integration testing.""" + +import asyncio +import sys +from app.core.rag.integrations.yuque.client import YuqueAPIClient +from app.core.rag.integrations.yuque.models import YuqueDocInfo + + +def main(yuque_user_id: str, # yuque User ID + yuque_token: str, # yuque Token + save_dir: str, # save file directory + ): + """Main entry point for the YuqueAPIClient.""" + # Create feishuAPIClient + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + + # Get all files from all repos + async def async_get_files(api_client: YuqueAPIClient): + async with api_client as client: + print("\n=== Fetching repositories ===") + repos = await client.get_user_repos() + print(f"Found {len(repos)} repositories:") + all_files = [] + for repo in repos: + # Get documents from repository + print(f"\n=== Fetching documents from '{repo.name}' ===") + docs = await client.get_repo_docs(repo.id) + all_files.extend(docs) + return all_files + files = asyncio.run(async_get_files(api_client)) + + try: + for doc in files: + print(f"\n{'=' * 80}") + print(f"id: {doc.id}") + print(f"type: {doc.type}") + print(f"slug: {doc.slug}") + print(f"title: {doc.title}") + print(f"book_id: {doc.book_id}") + # print(f"format: {doc.format}") + # print(f"body: {doc.body}") + # print(f"body_draft: {doc.body_draft}") + # print(f"body_html: {doc.body_html}") + print(f"public: {doc.public}") + print(f"status: {doc.status}") + print(f"created_at: {doc.created_at}") + print(f"updated_at: {doc.updated_at}") + print(f"published_at: {doc.published_at}") + print(f"word_count: {doc.word_count}") + print(f"cover: {doc.cover}") + print(f"description: {doc.description}") + print(f"{'=' * 80}\n") + # download document from Feishu FileInfo + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + print(file_path) + + except KeyboardInterrupt: + print("\n\nfeishu integration interrupted by user.") + + except Exception as e: + print(f"\n\nError during feishu integration: {e}") + sys.exit(1) + + +if __name__ == "__main__": + yuque_user_id = "" + yuque_token = "" + save_dir = "/Volumes/MacintoshBD/Repository/RedBearAI/MemoryBear/api/files/" + main(yuque_user_id, yuque_token, save_dir) diff --git a/api/app/core/rag/integrations/yuque/client.py b/api/app/core/rag/integrations/yuque/client.py new file mode 100644 index 00000000..444d9d31 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/client.py @@ -0,0 +1,544 @@ +"""Yuque API client for document operations.""" + +import os +import re +from typing import Optional, List +from datetime import datetime, timedelta +import httpx +import urllib.parse +import json +from openpyxl import Workbook +from openpyxl.styles import Font, Alignment, PatternFill +from openpyxl.utils import get_column_letter +import zlib + +from app.core.rag.integrations.yuque.exceptions import ( + YuqueAuthError, + YuqueAPIError, + YuqueNotFoundError, + YuquePermissionError, + YuqueRateLimitError, + YuqueNetworkError, +) +from app.core.rag.integrations.yuque.models import YuqueDocInfo, YuqueRepoInfo +from app.core.rag.integrations.yuque.retry import with_retry + + +class YuqueAPIClient: + """Yuque API client for document synchronization.""" + + def __init__( + self, + user_id: str, + token: str, + api_base_url: str = "https://www.yuque.com/api/v2", + timeout: int = 30, + max_retries: int = 3 + ): + """ + Initialize Yuque API client. + + Args: + user_id: Yuque user ID or login name + token: Yuque personal access token + api_base_url: Yuque API base URL + timeout: Request timeout in seconds + max_retries: Maximum number of retries + """ + self.user_id = user_id + self.token = token + self.api_base_url = api_base_url + self.timeout = timeout + self.max_retries = max_retries + self._http_client: Optional[httpx.AsyncClient] = None + + async def __aenter__(self): + """Async context manager entry.""" + self._http_client = httpx.AsyncClient( + base_url=self.api_base_url, + timeout=self.timeout, + headers={ + "Content-Type": "application/json", + "X-Auth-Token": self.token, + "User-Agent": "Yuque-Integration-Client" + } + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._http_client: + await self._http_client.aclose() + + def _handle_api_error(self, response: httpx.Response): + """Handle API error responses.""" + try: + data = response.json() + except Exception: + data = {} + + status_code = response.status_code + error_msg = data.get("message", "Unknown error") + + # Rate limit errors + if status_code == 429: + raise YuqueRateLimitError( + f"Rate limit exceeded: {error_msg}", + error_code=str(status_code), + details=data + ) + # Not found errors + elif status_code == 404: + raise YuqueNotFoundError( + f"Resource not found: {error_msg}", + error_code=str(status_code), + details=data + ) + # Permission errors + elif status_code == 403: + raise YuquePermissionError( + f"Permission denied: {error_msg}", + error_code=str(status_code), + details=data + ) + # Authentication errors + elif status_code == 401: + raise YuqueAuthError( + f"Authentication failed: {error_msg}", + error_code=str(status_code), + details=data + ) + # Generic API error + else: + raise YuqueAPIError( + f"API error: {error_msg}", + error_code=str(status_code), + details=data + ) + + @with_retry + async def get_user_repos(self) -> List[YuqueRepoInfo]: + """ + Get all repositories (知识库) for the user. + + Returns: + List of YuqueRepoInfo objects + + Raises: + YuqueAPIError: If API call fails + """ + try: + if not self._http_client: + raise YuqueAPIError("HTTP client not initialized") + + response = await self._http_client.get(f"/users/{self.user_id}/repos") + + if response.status_code != 200: + self._handle_api_error(response) + + data = response.json() + repos_data = data.get("data", []) + + repos = [] + for repo_data in repos_data: + try: + repo = YuqueRepoInfo( + id=repo_data.get("id"), + type=repo_data.get("type", ""), + name=repo_data.get("name", ""), + namespace=repo_data.get("namespace", ""), + slug=repo_data.get("slug", ""), + description=repo_data.get("description"), + public=repo_data.get("public", 0), + items_count=repo_data.get("items_count", 0), + created_at=datetime.fromisoformat(repo_data.get("created_at", "").replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(repo_data.get("updated_at", "").replace("Z", "+00:00")) + ) + repos.append(repo) + except (ValueError, TypeError, KeyError) as e: + # Skip invalid repo entries + continue + + return repos + + except httpx.HTTPError as e: + raise YuqueAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (YuqueAPIError, YuqueAuthError)): + raise + raise YuqueAPIError(f"Unexpected error: {str(e)}") + + @with_retry + async def get_repo_docs(self, book_id: int) -> List[YuqueDocInfo]: + """ + Get all documents in a repository. + + Args: + book_id: repository id + + Returns: + List of YuqueDocInfo objects (without body content) + + Raises: + YuqueAPIError: If API call fails + """ + try: + if not self._http_client: + raise YuqueAPIError("HTTP client not initialized") + + response = await self._http_client.get(f"/repos/{book_id}/docs") + + if response.status_code != 200: + self._handle_api_error(response) + + data = response.json() + docs_data = data.get("data", []) + + docs = [] + for doc_data in docs_data: + try: + published_at = doc_data.get("published_at") + doc = YuqueDocInfo( + id=doc_data.get("id"), + type=doc_data.get("type", ""), + slug=doc_data.get("slug", ""), + title=doc_data.get("title", ""), + book_id=doc_data.get("book_id"), + format=doc_data.get("format", "markdown"), + body=None, # Body not included in list API + body_draft=None, + body_html=None, + public=doc_data.get("public", 0), + status=doc_data.get("status", 0), + created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")), + published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None, + word_count=doc_data.get("word_count", 0), + cover=doc_data.get("cover"), + description=doc_data.get("description") + ) + docs.append(doc) + except (ValueError, TypeError, KeyError) as e: + # Skip invalid doc entries + continue + + return docs + + except httpx.HTTPError as e: + raise YuqueAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (YuqueAPIError, YuqueNotFoundError)): + raise + raise YuqueAPIError(f"Unexpected error: {str(e)}") + + @with_retry + async def get_doc_detail(self, id: int) -> YuqueDocInfo: + """ + Get detailed document information including content. + + Args: + id: document ID + + Returns: + YuqueDocInfo object with full content + + Raises: + YuqueAPIError: If API call fails + """ + try: + if not self._http_client: + raise YuqueAPIError("HTTP client not initialized") + + response = await self._http_client.get( + f"/repos/docs/{id}", + params={"raw": 1} # Get raw markdown content + ) + + if response.status_code != 200: + self._handle_api_error(response) + + data = response.json() + doc_data = data.get("data", {}) + + published_at = doc_data.get("published_at") + doc = YuqueDocInfo( + id=doc_data.get("id"), + type=doc_data.get("type", ""), + slug=doc_data.get("slug", ""), + title=doc_data.get("title", ""), + book_id=doc_data.get("book_id"), + format=doc_data.get("format", "markdown"), + body=doc_data.get("body", ""), + body_draft=doc_data.get("body_draft"), + body_html=doc_data.get("body_html"), + public=doc_data.get("public", 0), + status=doc_data.get("status", 0), + created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")), + published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None, + word_count=doc_data.get("word_count", 0), + cover=doc_data.get("cover"), + description=doc_data.get("description") + ) + + return doc + + except httpx.HTTPError as e: + raise YuqueAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (YuqueAPIError, YuqueNotFoundError)): + raise + raise YuqueAPIError(f"Unexpected error: {str(e)}") + + async def download_document( + self, + doc: YuqueDocInfo, + save_dir: str + ) -> str: + """ + Download document content to local file. + + Args: + doc: Document info (can be without body) + save_dir: Directory to save the file + + Returns: + Full path to the saved file + + Raises: + YuqueAPIError: If download fails + """ + try: + # Get full document content if not already loaded + if not doc.body: + doc = await self.get_doc_detail(doc.id) + + # Sanitize filename + filename = re.sub(r'[\/:*?"<>|]', '_', doc.title) + + # Determine file extension based on format + content = doc.body or "" + if doc.format == "markdown": + file_extension = "md" + elif doc.format == "lake": + file_extension = "md" # Save lake format as markdown + elif doc.format == "html": + file_extension = "html" + elif doc.format == "lakesheet": + file_extension = "xlsx" + + body_data = json.loads(doc.body) + sheet_data = body_data.get("sheet", "") + try: + sheet_raw = zlib.decompress(bytes(sheet_data, 'latin-1')) + except Exception as e: + print(f"Error decompressing sheet data: {e}") + raise ValueError("Invalid or unsupported sheet data format.") + try: + sheet_text = sheet_raw.decode("utf-8") # 假设是 UTF-8 编码 + except UnicodeDecodeError: + sheet_text = sheet_raw.decode("gbk") # 如果 UTF-8 解码失败,尝试 GBK + + file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}") + self.generate_excel_from_sheet(sheet_text, file_full_path) + return file_full_path + else: + file_extension = "txt" + + file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}") + # Remove existing file if it exists + if os.path.exists(file_full_path): + os.remove(file_full_path) + + # Write content to file + with open(file_full_path, "w", encoding="utf-8") as file: + file.write(content) + + return file_full_path + + except Exception as e: + if isinstance(e, YuqueAPIError): + raise + raise YuqueAPIError(f"Unexpected error during file download: {str(e)}") + + def generate_excel_from_sheet(self, sheet_text: str, save_path: str): + """ + 将解析的 sheet_text 数据转换为 Excel 文件。 + + Args: + sheet_text (str): JSON 格式的 sheet 数据。 + save_path (str): Excel 文件的保存路径。 + """ + try: + # 解析 JSON 数据 + sheets = json.loads(sheet_text) + + if not isinstance(sheets, list): + raise ValueError("sheet_text must be a JSON array of sheets.") + + # 创建一个新的 Excel 工作簿 + workbook = Workbook() + + for sheet_index, sheet_data in enumerate(sheets): + sheet_name = sheet_data.get("name", f"Sheet{sheet_index + 1}") + row_data = sheet_data.get("data", {}) + merge_cells = sheet_data.get("mergeCells", {}) + rows_styles = sheet_data.get("rows", []) + cols_styles = sheet_data.get("columns", []) + + # 创建 Sheet + if sheet_index == 0: + worksheet = workbook.active + worksheet.title = sheet_name + else: + worksheet = workbook.create_sheet(title=sheet_name) + + # 设置列宽 + for col_index, col_style in enumerate(cols_styles): + col_width = col_style.get("size", 82.125) / 7.0 + col_letter = get_column_letter(col_index + 1) # Excel 列从1开始 + worksheet.column_dimensions[col_letter].width = col_width + + # 设置行高 + for row_index, row_style in enumerate(rows_styles): + row_height = row_style.get("size", 24) / 1.5 + worksheet.row_dimensions[row_index + 1].height = row_height + + # 写入单元格数据 + for r_index, row in row_data.items(): + for c_index, cell in row.items(): + # 防御性检查:确保行号和列号都是有效的整数 + try: + row_number = int(r_index) + 1 + col_number = int(c_index) + 1 + except ValueError: + print(f"Invalid row or column index: r_index={r_index}, c_index={c_index}") + continue + + if col_number < 1 or col_number > 16384: # Excel 最大列数支持到 XFD,即 16384 列 + print(f"Invalid column index: c_index={c_index}") + continue + + cell_obj = worksheet.cell(row=row_number, column=col_number) + + # 处理值和公式 + cell_value = cell.get("value", "") + if isinstance(cell_value, dict): + # 检查是否为公式 + if cell_value.get("class") == "formula" and "formula" in cell_value: + cell_obj.value = f"={cell_value['formula']}" # 写入公式 + else: + cell_obj.value = cell_value.get("value", "") # 写入值 + else: + cell_obj.value = cell_value # 写入简单值 + + # 应用样式 + style = cell.get("style", {}) + self.apply_cell_style(cell_obj, style) + + # 合并单元格 + for key, merge_def in merge_cells.items(): + start_row = merge_def["row"] + 1 + start_col = merge_def["col"] + 1 + end_row = start_row + merge_def["rowCount"] - 1 + end_col = start_col + merge_def["colCount"] - 1 + worksheet.merge_cells( + start_row=start_row, start_column=start_col, end_row=end_row, end_column=end_col + ) + + # 保存 Excel 文件 + workbook.save(save_path) + print(f"Excel file successfully saved to: {save_path}") + + except Exception as e: + print(f"Error generating Excel file: {e}") + + + def apply_cell_style(self, cell, style): + """ + 应用单元格样式,包括字体、对齐、背景颜色等。 + + Args: + cell: openpyxl 的单元格对象。 + style: 字典格式的样式信息。 + """ + # 定义允许的对齐值 + allowed_horizontal_alignments = {"general", "left", "center", "centerContinuous", "right", "fill", "justify", + "distributed"} + allowed_vertical_alignments = {"top", "center", "justify", "distributed", "bottom"} + + # 处理字体 + font = Font( + size=style.get("fontSize", 11), + bold=style.get("fontWeight", False), + italic=style.get("fontStyle", "normal") == "italic", + underline="single" if style.get("underline", False) else None, + color=self.convert_color_to_hex(style.get("color", "#000000")), + ) + cell.font = font + + # 处理对齐方式 + horizontal_alignment = style.get("hAlign", "left") + vertical_alignment = style.get("vAlign", "top") + + # 如果对齐值无效,则使用默认值 + if horizontal_alignment not in allowed_horizontal_alignments: + horizontal_alignment = "left" + if vertical_alignment not in allowed_vertical_alignments: + vertical_alignment = "top" + + alignment = Alignment( + horizontal=horizontal_alignment, + vertical=vertical_alignment, + wrap_text=style.get("overflow") == "wrap", + ) + cell.alignment = alignment + + # 处理背景颜色 + background_color = style.get("backColor", None) + if background_color: + hex_color = self.convert_color_to_hex(background_color) + if hex_color: + cell.fill = PatternFill( + start_color=hex_color, + end_color=hex_color, + fill_type="solid" + ) + + def convert_color_to_hex(self, color): + """ + 将颜色从 `rgba(...)` 或 `rgb(...)` 转换为 aRGB 十六进制格式。 + + Args: + color (str): 原始颜色字符串,如 `rgba(255,255,0,1.00)` 或 `#FFFFFF`。 + + Returns: + str: 转换后的颜色字符串(符合 openpyxl 的格式),例如 `FFFF0000`。 + """ + try: + if not color: + return None + + # 如果是 `#RRGGBB` 或 `#AARRGGBB` 格式,直接返回 + if color.startswith("#"): + return color.lstrip("#").upper() + + # 如果是 `rgb(...)` 格式,例如 `rgb(255,255,0)` + if color.startswith("rgb("): + rgb_values = color.strip("rgb()").split(",") + red, green, blue = [int(v) for v in rgb_values] + return f"FF{red:02X}{green:02X}{blue:02X}" + + # 如果是 `rgba(...)` 格式,例如 `rgba(255,255,0,1.00)` + if color.startswith("rgba("): + rgba_values = color.strip("rgba()").split(",") + red, green, blue = [int(v) for v in rgba_values[:3]] + alpha = float(rgba_values[3]) + alpha_hex = int(alpha * 255) # 将透明度转换为 [00, FF] + return f"{alpha_hex:02X}{red:02X}{green:02X}{blue:02X}" + + # 返回默认颜色 + return None + except Exception as e: + print(f"Error parsing color '{color}': {e}") + return None diff --git a/api/app/core/rag/integrations/yuque/exceptions.py b/api/app/core/rag/integrations/yuque/exceptions.py new file mode 100644 index 00000000..e862323c --- /dev/null +++ b/api/app/core/rag/integrations/yuque/exceptions.py @@ -0,0 +1,46 @@ +"""Exception classes for Yuque integration.""" + + +class YuqueError(Exception): + """Base exception for all Yuque-related errors.""" + + def __init__(self, message: str, error_code: str = None, details: dict = None): + super().__init__(message) + self.message = message + self.error_code = error_code + self.details = details or {} + + +class YuqueAuthError(YuqueError): + """Authentication error with Yuque API.""" + pass + + +class YuqueAPIError(YuqueError): + """General API error from Yuque.""" + pass + + +class YuqueNotFoundError(YuqueError): + """Resource not found error (404).""" + pass + + +class YuquePermissionError(YuqueError): + """Permission denied error (403).""" + pass + + +class YuqueRateLimitError(YuqueError): + """Rate limit exceeded error (429).""" + pass + + +class YuqueNetworkError(YuqueError): + """Network-related error (timeout, connection failure).""" + pass + + +class YuqueDataError(YuqueError): + """Data parsing or validation error.""" + pass diff --git a/api/app/core/rag/integrations/yuque/models.py b/api/app/core/rag/integrations/yuque/models.py new file mode 100644 index 00000000..6230aa69 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/models.py @@ -0,0 +1,42 @@ +"""Data models for Yuque integration.""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + + +@dataclass +class YuqueRepoInfo: + """Repository (知识库) information from Yuque.""" + id: int # 知识库 ID + type: str # 类型 (Book:文档, Design:图集, Sheet:表格, Resource:资源) + name: str # 名称 + namespace: str # 完整路径: user/repo format + slug: str # 路径 + description: Optional[str] # 简介 + public: int # 公开性 (0:私密, 1:公开, 2:企业内公开) + items_count: int # 文档数量 + created_at: datetime # 创建时间 + updated_at: datetime # 更新时间 + + +@dataclass +class YuqueDocInfo: + """Document information from Yuque.""" + id: int # 文档 ID + type: str # 文档类型 (Doc:普通文档, Sheet:表格, Thread:话题, Board:图集, Table:数据表) + slug: str # 路径 + title: str # 标题 + book_id: int # 归属知识库 ID + format: str # 内容格式 (markdown:Markdown 格式, lake:语雀 Lake 格式, html:HTML 标准格式, lakesheet:语雀表格) + body: Optional[str] # 正文原始内容 + body_draft: Optional[str] # 正文草稿内容 + body_html: Optional[str] # 正文 HTML 标准格式内容 + public: int # 公开性 (0:私密, 1:公开, 2:企业内公开) + status: int # 状态 (0:草稿, 1:发布) + created_at: datetime # 创建时间 + updated_at: datetime # 更新时间 + published_at: Optional[datetime] # 发布时间 + word_count: int # 内容字数 + cover: Optional[str] # 封面 + description: Optional[str] # 摘要 diff --git a/api/app/core/rag/integrations/yuque/retry.py b/api/app/core/rag/integrations/yuque/retry.py new file mode 100644 index 00000000..a68d6b47 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/retry.py @@ -0,0 +1,134 @@ +"""Retry strategy for Yuque API calls.""" + +import asyncio +import functools +from typing import Callable, TypeVar +import httpx + +from app.core.rag.integrations.yuque.exceptions import ( + YuqueAuthError, + YuquePermissionError, + YuqueNotFoundError, + YuqueRateLimitError, + YuqueNetworkError, + YuqueDataError, + YuqueAPIError, +) + +T = TypeVar('T') + + +class RetryStrategy: + """Retry strategy for API calls.""" + + # Retryable error types + RETRYABLE_ERRORS = ( + YuqueNetworkError, + YuqueRateLimitError, + httpx.TimeoutException, + httpx.ConnectError, + httpx.ReadError, + ) + + # Non-retryable error types + NON_RETRYABLE_ERRORS = ( + YuqueAuthError, + YuquePermissionError, + YuqueNotFoundError, + YuqueDataError, + ) + + # Retry configuration + MAX_RETRIES = 3 + BACKOFF_DELAYS = [1, 2, 4] # seconds + + @classmethod + def is_retryable(cls, error: Exception) -> bool: + """Check if an error is retryable.""" + # Check for specific retryable errors + if isinstance(error, cls.RETRYABLE_ERRORS): + return True + + # Check for non-retryable errors + if isinstance(error, cls.NON_RETRYABLE_ERRORS): + return False + + # Check for HTTP status codes + if isinstance(error, httpx.HTTPStatusError): + status_code = error.response.status_code + # Retry on 429 (rate limit), 503 (service unavailable), 502 (bad gateway) + if status_code in [429, 502, 503]: + return True + # Don't retry on 4xx errors (except 429) + if 400 <= status_code < 500: + return False + # Retry on 5xx errors + if 500 <= status_code < 600: + return True + + # Check for YuqueRateLimitError + if isinstance(error, YuqueRateLimitError): + return True + + return False + + @classmethod + async def execute_with_retry( + cls, + func: Callable[..., T], + *args, + **kwargs + ) -> T: + """ + Execute a function with retry logic. + + Args: + func: Async function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Function result + + Raises: + Exception: The last exception if all retries fail + """ + last_exception = None + + for attempt in range(cls.MAX_RETRIES + 1): + try: + return await func(*args, **kwargs) + except Exception as e: + last_exception = e + + # Don't retry if not retryable + if not cls.is_retryable(e): + raise + + # Don't retry if this was the last attempt + if attempt >= cls.MAX_RETRIES: + raise + + # Wait before retrying + delay = cls.BACKOFF_DELAYS[attempt] if attempt < len(cls.BACKOFF_DELAYS) else cls.BACKOFF_DELAYS[-1] + await asyncio.sleep(delay) + + # Should not reach here, but raise last exception if we do + if last_exception: + raise last_exception + + +def with_retry(func: Callable[..., T]) -> Callable[..., T]: + """ + Decorator to add retry logic to async functions. + + Usage: + @with_retry + async def my_api_call(): + ... + """ + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await RetryStrategy.execute_with_retry(func, *args, **kwargs) + + return wrapper diff --git a/api/app/models/file_model.py b/api/app/models/file_model.py index 842e3dc8..44a7d613 100644 --- a/api/app/models/file_model.py +++ b/api/app/models/file_model.py @@ -14,4 +14,5 @@ class File(Base): file_name = Column(String, index=True, nullable=False, comment="file name or folder name,default folder name is /") file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf") file_size = Column(Integer, default=0, comment="file size(byte)") + file_url = Column(String, index=True, nullable=True, comment="file comes from a website url") created_at = Column(DateTime, default=datetime.datetime.now) \ No newline at end of file diff --git a/api/app/models/knowledge_model.py b/api/app/models/knowledge_model.py index 8f0909d3..fbebe1b4 100644 --- a/api/app/models/knowledge_model.py +++ b/api/app/models/knowledge_model.py @@ -57,6 +57,17 @@ class Knowledge(Base): parser_id = Column(String, index=True, default="naive", comment="default parser ID") parser_config = Column(JSON, nullable=False, default={ + "entry_url": "https://ai.redbearai.com", + "max_pages": 20, + "delay_seconds": 1.0, + "timeout_seconds": 10, + "user_agent": "KnowledgeBaseCrawler/1.0", + "yuque_user_id": "User ID", + "yuque_token": "Token", + "feishu_app_id": "App ID", + "feishu_app_secret": "App Secret", + "feishu_folder_token": "Folder Token", + "sync_cron": "30 7 * * 1-5", "layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n", diff --git a/api/app/schemas/file_schema.py b/api/app/schemas/file_schema.py index 00f1a148..7245671a 100644 --- a/api/app/schemas/file_schema.py +++ b/api/app/schemas/file_schema.py @@ -10,6 +10,8 @@ class FileBase(BaseModel): file_name: str file_ext: str file_size: int + file_url: str | None = None + created_at: datetime.datetime | None = None class FileCreate(FileBase): @@ -26,6 +28,7 @@ class FileUpdate(BaseModel): file_name: str | None = Field(None) file_ext: str | None = Field(None) file_size: str | None = Field(None) + file_url: str | None = Field(None) class File(FileBase): diff --git a/api/app/tasks.py b/api/app/tasks.py index a46a3a7b..29b0e485 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -7,6 +7,8 @@ import uuid from uuid import UUID from datetime import datetime, timezone from math import ceil +from pathlib import Path +import shutil from typing import Any, Dict, List, Optional import redis @@ -16,8 +18,13 @@ import trio # Import a unified Celery instance from app.celery_app import celery_app from app.core.config import settings +from app.core.rag.crawler.web_crawler import WebCrawler from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.feishu.models import FileInfo +from app.core.rag.integrations.yuque.client import YuqueAPIClient +from app.core.rag.integrations.yuque.models import YuqueDocInfo from app.core.rag.llm.chat_model import Base from app.core.rag.llm.cv_model import QWenCV from app.core.rag.llm.embedding_model import OpenAIEmbed @@ -29,7 +36,9 @@ from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ) from app.db import get_db, get_db_context from app.models.document_model import Document +from app.models.file_model import File from app.models.knowledge_model import Knowledge +from app.schemas import file_schema, document_schema from app.services.memory_agent_service import MemoryAgentService @@ -382,6 +391,480 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): db.close() +@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") +def sync_knowledge_for_kb(kb_id: uuid.UUID): + """ + sync knowledge document and Document parsing, vectorization, and storage + """ + db = next(get_db()) # Manually call the generator + db_knowledge = None + try: + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() + # 1. get vector_service + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + # 2. sync data + match db_knowledge.type: + case "Web": # Crawl webpages in batches through a web crawler + entry_url = db_knowledge.parser_config.get("entry_url", "") + max_pages = db_knowledge.parser_config.get("max_pages", 20) + delay_seconds = db_knowledge.parser_config.get("delay_seconds", 1.0) + timeout_seconds = db_knowledge.parser_config.get("timeout_seconds", 10) + user_agent = db_knowledge.parser_config.get("user_agent", "KnowledgeBaseCrawler/1.0") + # Create crawler + crawler = WebCrawler( + entry_url=entry_url, + max_pages=max_pages, + delay_seconds=delay_seconds, + timeout_seconds=timeout_seconds, + user_agent=user_agent + ) + try: + # 初始化存储已爬取 URLs 的集合 + file_urls = set() + # crawl entry_url by yield + for crawled_document in crawler.crawl(): + file_urls.add(crawled_document.url) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == crawled_document.url).first() + if db_file: + if db_file.file_size == crawled_document.content_length: # same + continue + else: # --update + if crawled_document.content_length: + # 1. update file + db_file.file_name = f"{crawled_document.title}.txt" + db_file.file_ext=".txt" + db_file.file_size=crawled_document.content_length + db.commit() + db.refresh(db_file) + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + content_bytes = crawled_document.content.encode('utf-8') + with open(save_path, "wb") as f: + f.write(content_bytes) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + if crawled_document.content_length: + # 1. upload file + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=f"{crawled_document.title}.txt", + file_ext=".txt", + file_size=crawled_document.content_length, + file_url=crawled_document.url, + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # Save file + content_bytes = crawled_document.content.encode('utf-8') + with open(save_path, "wb") as f: + f.write(content_bytes) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during crawl: {e}") + case "Third-party": # Integration of knowledge bases from three parties + yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "") + feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "") + if yuque_user_id: # Yuque Knowledge Base + yuque_token = db_knowledge.parser_config.get("yuque_token", "") + # Create yuqueAPIClient + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + try: + # 初始化存储获取语雀 URLs 的集合 + file_urls = set() + + # Get all files from all repos + async def async_get_files(api_client: YuqueAPIClient): + async with api_client as client: + print("\n=== Fetching repositories ===") + repos = await client.get_user_repos() + print(f"Found {len(repos)} repositories:") + all_files = [] + for repo in repos: + # Get documents from repository + print(f"\n=== Fetching documents from '{repo.name}' ===") + docs = await client.get_repo_docs(repo.id) + all_files.extend(docs) + return all_files + + files = asyncio.run(async_get_files(api_client)) + for doc in files: + file_urls.add(doc.slug) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == doc.slug).first() + if db_file: + if db_file.created_at == doc.updated_at: # same + continue + else: # --update + # 1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + + # download document from Feishu FileInfo + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # update db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + db_file.file_name = file_name + db_file.file_ext = file_extension.lower() + db_file.file_size = file_size + db_file.created_at = doc.updated_at + db.commit() + db.refresh(db_file) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.created_at = db_file.created_at + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + # 1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + + # download document from Feishu FileInfo + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + # add db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=file_name, + file_ext=file_extension.lower(), + file_size=file_size, + file_url=doc.slug, + created_at=doc.updated_at + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Save file + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", + value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during fetch feishu: {e}") + if feishu_app_id: # Feishu Knowledge Base + feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "") + feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "") + # Create feishuAPIClient + api_client = FeishuAPIClient( + app_id=feishu_app_id, + app_secret=feishu_app_secret + ) + try: + # 初始化存储获取飞书 URLs 的集合 + file_urls = set() + # Get all files from folder + async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str): + async with api_client as client: + files = await client.list_all_folder_files(feishu_folder_token, recursive=True) + return files + files = asyncio.run(async_get_files(api_client, feishu_folder_token)) + # Filter out folders, only sync documents + documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]] + for doc in documents: + file_urls.add(doc.url) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == doc.url).first() + if db_file: + if db_file.created_at == doc.modified_time: # same + continue + else: # --update + # 1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + # download document from Feishu FileInfo + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(document=doc, save_dir=save_dir) + return file_path + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # update db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + db_file.file_name = file_name + db_file.file_ext = file_extension.lower() + db_file.file_size = file_size + db_file.created_at = doc.modified_time + db.commit() + db.refresh(db_file) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.created_at = db_file.created_at + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + # 1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + # download document from Feishu FileInfo + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(document=doc, save_dir=save_dir) + return file_path + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + # add db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=file_name, + file_ext=file_extension.lower(), + file_size=file_size, + file_url=doc.url, + created_at = doc.modified_time + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Save file + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", + value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during fetch feishu: {e}") + case _: # General + print(f"General: No synchronization needed\n") + + + result = f"sync knowledge '{db_knowledge.name}' processed successfully." + return result + except Exception as e: + if 'db_knowledge' in locals(): + print(f"Failed to sync knowledge:{str(e)}\n") + result = f"sync knowledge '{db_knowledge.name}' failed." + return result + finally: + db.close() + + @celery_app.task(name="app.core.memory.agent.read_message", bind=True) def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]: diff --git a/api/pyproject.toml b/api/pyproject.toml index 6d23a3b9..66b1a295 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -141,6 +141,8 @@ dependencies = [ "flower>=2.0.1", "aiofiles>=23.0.0", "owlready2>=0.46", + "lxml>=4.9.0", + "httpx>=0.28.0", ] [tool.pytest.ini_options] diff --git a/api/requirements.txt b/api/requirements.txt index 6cdae2d1..144c0db2 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -134,3 +134,5 @@ xlrd==2.0.2 oss2>=2.18.0 boto3>=1.28.0 aiofiles>=23.0.0 +lxml>=4.9.0 +httpx>=0.28.0 From db1da4a61ad08c022eaf5f48116a02d2d7e5d5cd Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Fri, 6 Feb 2026 12:30:57 +0800 Subject: [PATCH 41/45] Fix/develop memory bug (#339) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 遗漏的历史映射 * 遗漏的历史映射 * fix_timeline_memories * fix_timeline_memories * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * Multiple independent transactions - single transaction * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id --- api/app/services/app_service.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 8a2a0428..38eb5f4c 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -955,11 +955,13 @@ class AppService: ).order_by( AgentConfig.updated_at.desc() ) - config = self.db.scalars(stmt).first() + config = self.db.scalars(stmt).first() + config_memory=config.memory + if 'memory_content' in config_memory: + config.memory['memory_config_id'] = config.memory.pop('memory_content') if config: return config - # 返回默认配置模板(不保存到数据库) logger.debug("配置不存在,返回默认模板", extra={"app_id": str(app_id)}) return self._create_default_agent_config(app_id) From efdd42426e0fda67937691aee6d90611bbed42c1 Mon Sep 17 00:00:00 2001 From: Mark <zhuwenhui5566@163.com> Date: Fri, 6 Feb 2026 12:36:08 +0800 Subject: [PATCH 42/45] [add] migration script --- .../versions/ef0787b85c35_202602061233.py | 32 +++++++++++++++++++ api/uv.lock | 4 +++ 2 files changed, 36 insertions(+) create mode 100644 api/migrations/versions/ef0787b85c35_202602061233.py diff --git a/api/migrations/versions/ef0787b85c35_202602061233.py b/api/migrations/versions/ef0787b85c35_202602061233.py new file mode 100644 index 00000000..1d08ec71 --- /dev/null +++ b/api/migrations/versions/ef0787b85c35_202602061233.py @@ -0,0 +1,32 @@ +"""202602061233 + +Revision ID: ef0787b85c35 +Revises: 9b28b66cf8e8 +Create Date: 2026-02-06 12:33:26.114673 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'ef0787b85c35' +down_revision: Union[str, None] = '9b28b66cf8e8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('files', sa.Column('file_url', sa.String(), nullable=True, comment='file comes from a website url')) + op.create_index(op.f('ix_files_file_url'), 'files', ['file_url'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_files_file_url'), table_name='files') + op.drop_column('files', 'file_url') + # ### end Alembic commands ### diff --git a/api/uv.lock b/api/uv.lock index 587fc5b0..a9bde1ed 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -3224,6 +3224,7 @@ dependencies = [ { name = "hanziconv" }, { name = "html5lib" }, { name = "httptools" }, + { name = "httpx" }, { name = "huggingface-hub" }, { name = "idna" }, { name = "jieba" }, @@ -3237,6 +3238,7 @@ dependencies = [ { name = "langchain-ollama" }, { name = "langchain-openai" }, { name = "langfuse" }, + { name = "lxml" }, { name = "mako" }, { name = "mammoth" }, { name = "markdown" }, @@ -3361,6 +3363,7 @@ requires-dist = [ { name = "hanziconv", specifier = "==0.3.2" }, { name = "html5lib", specifier = "==1.1" }, { name = "httptools", specifier = "==0.7.1" }, + { name = "httpx", specifier = ">=0.28.0" }, { name = "huggingface-hub", specifier = "==0.25.2" }, { name = "idna", specifier = "==3.11" }, { name = "jieba", specifier = ">=0.42.1" }, @@ -3375,6 +3378,7 @@ requires-dist = [ { name = "langchain-ollama" }, { name = "langchain-openai", specifier = ">=1.0.2" }, { name = "langfuse", specifier = ">=3.10.0" }, + { name = "lxml", specifier = ">=4.9.0" }, { name = "mako", specifier = "==1.3.10" }, { name = "mammoth", specifier = "==1.11.0" }, { name = "markdown", specifier = "==3.8" }, From 75f59a86c8f693a93469b886279be000acc6f298 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Fri, 6 Feb 2026 13:42:36 +0800 Subject: [PATCH 43/45] Fix/develop memory bug (#341) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 遗漏的历史映射 * 遗漏的历史映射 * fix_timeline_memories * fix_timeline_memories * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * Multiple independent transactions - single transaction * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id --- api/app/services/app_service.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 38eb5f4c..4583fadb 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -957,11 +957,16 @@ class AppService: ) config = self.db.scalars(stmt).first() - config_memory=config.memory - if 'memory_content' in config_memory: - config.memory['memory_config_id'] = config.memory.pop('memory_content') + + try: + config_memory=config.memory + if 'memory_content' in config_memory: + config.memory['memory_config_id'] = config.memory.pop('memory_content') + except: + logger.debug("记忆配置不存在") if config: return config + # 返回默认配置模板(不保存到数据库) logger.debug("配置不存在,返回默认模板", extra={"app_id": str(app_id)}) return self._create_default_agent_config(app_id) From c566d22836777e95aadcff99d16fb506c3fb2d14 Mon Sep 17 00:00:00 2001 From: zhaoying <yzhao96@best-inc.com> Date: Fri, 6 Feb 2026 13:45:03 +0800 Subject: [PATCH 44/45] fix(web): ui update --- .../views/ModelManagement/components/MultiKeyConfigModal.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx b/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx index 5e362025..95ae031f 100644 --- a/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx +++ b/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx @@ -81,9 +81,9 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod {model.api_keys && model.api_keys.length > 0 && ( <div className="rb:mb-4"> {model.api_keys.map((key) => ( - <div key={key.id} className="rb:flex rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2"> + <div key={key.id} className="rb:flex rb:gap-3 rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2"> <div className="rb:flex-1"> - <div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium">{key.api_key}</div> + <div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium rb:break-all">{key.api_key}</div> <div className="rb:text-[#5B6167] rb:text-[12px] rb:mt-1">{key.api_base}</div> </div> <Button type="primary" danger ghost onClick={() => handleDelete(key.id)}>{t('common.remove')}</Button> From 8d4c5b5b339d0b56f3134592496125d960cf9833 Mon Sep 17 00:00:00 2001 From: zhaoying <yzhao96@best-inc.com> Date: Fri, 6 Feb 2026 14:03:32 +0800 Subject: [PATCH 45/45] feat(web): memory extraction engine add custom_text --- web/src/api/memory.ts | 2 +- web/src/i18n/en.ts | 3 ++- web/src/i18n/zh.ts | 3 ++- .../MemoryExtractionEngine/components/Result.tsx | 13 ++++++++++++- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index 6f4e7f0e..987ef358 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -256,7 +256,7 @@ export const updateMemoryExtractionConfig = (values: ExtractionConfigForm) => { return request.post('/memory-storage/update_config_extracted', values) } // Memory Extraction Engine - Pilot run -export const pilotRunMemoryExtractionConfig = (values: { config_id: number | string; dialogue_text: string; }, onMessage?: (data: SSEMessage[]) => void) => { +export const pilotRunMemoryExtractionConfig = (values: { config_id: number | string; dialogue_text: string; custom_text?: string; }, onMessage?: (data: SSEMessage[]) => void) => { return handleSSE('/memory-storage/pilot_run', values, onMessage) } // Emotion Engine - Get configuration diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index fe0fbc37..9d706ff6 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1543,7 +1543,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re text_preprocessing_desc: 'Text split into {{count}} semantic fragments', knowledge_extraction_desc: 'Knowledge extraction completed, identified {{entities}} entities, {{statements}} statements, {{temporal_ranges_count}} temporal extractions, {{triplets}} triplets', creating_nodes_edges_desc: 'Entity relationship creation completed, {{num}} relationships in total', - deduplication_desc: 'Deduplication and disambiguation completed, {{count}} unique entities in total' + deduplication_desc: 'Deduplication and disambiguation completed, {{count}} unique entities in total', + custom_text: 'Debug Text', }, memoryConversation: { searchPlaceholder: 'Enter user ID...', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 7fc8b652..a7ef34ac 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1617,7 +1617,8 @@ export const zh = { text_preprocessing_desc: '文本切分为{{count}}个语义片段', knowledge_extraction_desc: '知识抽取完成,共识别{{entities}}个实体,{{statements}}个句子, {{temporal_ranges_count}}个时间提取, {{triplets}}个三元组', creating_nodes_edges_desc: '实体关系创建完成,共{{num}}条关系', - deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体' + deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体', + custom_text: '调试文本', }, memoryConversation: { chatEmpty:'有什么我可以帮您的吗?', diff --git a/web/src/views/MemoryExtractionEngine/components/Result.tsx b/web/src/views/MemoryExtractionEngine/components/Result.tsx index 6fdeb2af..cb89661a 100644 --- a/web/src/views/MemoryExtractionEngine/components/Result.tsx +++ b/web/src/views/MemoryExtractionEngine/components/Result.tsx @@ -13,7 +13,7 @@ import { type FC, useState } from 'react' import { useParams } from 'react-router-dom' import { useTranslation } from 'react-i18next' -import { Space, Button, Progress } from 'antd' +import { Space, Button, Progress, Form, Input } from 'antd' import { ExclamationCircleFilled, CheckCircleFilled, ClockCircleOutlined, LoadingOutlined } from '@ant-design/icons' import clsx from 'clsx' import type { AnyObject } from 'antd/es/_util/type'; @@ -79,6 +79,8 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => { const [creatingNodesEdges, setCreatingNodesEdges] = useState<ModuleItem>(initObj as ModuleItem) const [deduplication, setDeduplication] = useState<ModuleItem>(initObj as ModuleItem) + const [runForm] = Form.useForm() + /** Run pilot test */ const handleRun = () => { if(!id) return @@ -187,6 +189,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => { pilotRunMemoryExtractionConfig({ config_id: id, dialogue_text: t('memoryExtractionEngine.exampleText'), + custom_text: runForm.getFieldValue('custom_text') }, handleStreamMessage) .finally(() => { setRunLoading(false) @@ -222,6 +225,14 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => { headerClassName="rb:pb-0! rb:pt-4!" bodyClassName="rb:min-h-[calc(100vh-388px)] rb:p-[16px_20px]!" > + <Form form={runForm} layout="vertical"> + <Form.Item + name="custom_text" + label={t('memoryExtractionEngine.custom_text')} + > + <Input.TextArea placeholder={t('common.pleaseEnter')} /> + </Form.Item> + </Form> <div className="rb:min-h-[calc(100vh-480px)] rb:overflow-y-auto"> {runLoading ? <>