diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index e1a1f7f4..c5f3d735 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -196,6 +196,11 @@ def update_config( api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + # 校验至少有一个字段需要更新 + if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") + return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") try: svc = DataConfigService(db) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index f36aa6c5..588c913c 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -52,6 +52,7 @@ from app.services.ontology_service import OntologyService from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.utils.validation.owl_validator import OWLValidator from app.services.model_service import ModelConfigService +from app.repositories.ontology_scene_repository import OntologySceneRepository api_logger = get_api_logger() @@ -116,27 +117,35 @@ def _get_ontology_service( detail=f"找不到指定的LLM模型: {llm_id}" ) - # 验证模型配置了API密钥 - if not model_config.api_keys: - logger.error(f"Model {llm_id} has no API key configuration") + # 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理) + from app.repositories.model_repository import ModelApiKeyRepository + api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) + if not api_keys: + logger.error(f"Model {llm_id} has no active API key") raise HTTPException( status_code=400, - detail="指定的LLM模型没有配置API密钥" + detail="指定的LLM模型没有可用的API密钥" ) + api_key_config = api_keys[0] - api_key_config = model_config.api_keys[0] - + is_composite = getattr(model_config, 'is_composite', False) logger.info( f"Using specified model - user: {current_user.id}, " - f"model_id: {llm_id}, model_name: {api_key_config.model_name}" + f"model_id: {llm_id}, model_name: {api_key_config.model_name}, " + f"is_composite: {is_composite}, api_key_id: {api_key_config.id}" ) # 创建模型配置对象 from app.core.models.base import RedBearModelConfig + # 对于组合模型,使用 API Key 的 provider;否则使用 model_config 的 provider + actual_provider = api_key_config.provider if is_composite else ( + getattr(model_config, 'provider', None) or "openai" + ) + llm_model_config = RedBearModelConfig( model_name=api_key_config.model_name, - provider=model_config.provider if hasattr(model_config, 'provider') else "openai", + provider=actual_provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, max_retries=3, @@ -648,6 +657,46 @@ async def delete_scene( return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e)) +@router.get("/scenes/simple", response_model=ApiResponse) +async def get_scenes_simple( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取场景简单列表(轻量级,用于下拉选择) + + 仅返回 scene_id 和 scene_name,不加载关联数据,响应速度快。 + 适用于前端下拉选择场景的场景。 + + Args: + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含场景简单列表 + + Examples: + GET /scenes/simple + 返回: {"data": [{"scene_id": "xxx", "scene_name": "场景1"}, ...]} + """ + api_logger.info(f"Simple scene list requested by user {current_user.id}") + + try: + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + repo = OntologySceneRepository(db) + scenes = repo.get_simple_list(workspace_id) + + api_logger.info(f"Simple scene list retrieved: {len(scenes)} scenes") + return success(data=scenes, msg="查询成功") + + except Exception as e: + api_logger.error(f"Failed to get simple scene list: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + @router.get("/scenes", response_model=ApiResponse) async def get_scenes( workspace_id: Optional[str] = None, diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 40cf068e..fae20ea2 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -7,30 +7,21 @@ LangChain Agent 封装 - 支持流式输出 - 使用 RedBearLLM 支持多提供商 """ -import os + import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse -from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.core.logging_config import get_business_logger -from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType -from app.repositories.memory_short_repository import LongTermMemoryRepository from app.services.memory_agent_service import ( get_end_user_connected_config, ) -from app.services.memory_konwledges_server import write_rag -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool - -from app.utils.config_utils import resolve_config_id - logger = get_business_logger() @@ -289,105 +280,6 @@ class LangChainAgent: return content_parts - async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): - db = next(get_db()) - #TODO: 魔法数字 - scope=6 - - try: - repo = LongTermMemoryRepository(db) - await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) - - from app.core.memory.agent.utils.redis_tool import write_store - result = write_store.get_session_by_userid(end_user_id) - - # Handle case where no session exists in Redis (returns False) - if not result or result is False: - logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update") - return - - if type=="chunk" or type=="aggregate": - data = await format_parsing(result, "dict") - chunk_data = data[:scope] - if len(chunk_data)==scope: - repo.upsert(end_user_id, chunk_data) - logger.info(f'写入短长期:') - else: - # TODO: This branch handles type="time" strategy, currently unused. - # Will be activated when time-based long-term storage is implemented. - # TODO: 魔法数字 - extract 5 to a constant - long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - logger.debug(f"No recent sessions in Redis for user {end_user_id}") - return - long_messages = await messages_parse(long_time_data) - repo.upsert(end_user_id, long_messages) - logger.info(f'写入短长期:') - finally: - db.close() - - async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): - """ - 写入记忆(支持结构化消息) - - Args: - storage_type: 存储类型 (neo4j/rag) - end_user_id: 终端用户ID - user_message: 用户消息内容 - ai_message: AI 回复内容 - user_rag_memory_id: RAG 记忆ID - actual_end_user_id: 实际用户ID - actual_config_id: 配置ID - - 逻辑说明: - - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - - Neo4j 模式:使用结构化消息列表 - 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] - 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) - 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 - """ - - db = next(get_db()) - try: - actual_config_id=resolve_config_id(actual_config_id, db) - - if storage_type == "rag": - # RAG 模式:组合消息为字符串格式(保持原有逻辑) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - else: - # Neo4j 模式:使用结构化消息列表 - structured_messages = [] - - # 始终添加用户消息(如果不为空) - if user_message: - structured_messages.append({"role": "user", "content": user_message}) - - # 只有当 AI 回复不为空时才添加 assistant 消息 - if ai_message: - structured_messages.append({"role": "assistant", "content": ai_message}) - - # 如果没有消息,直接返回 - if not structured_messages: - logger.warning(f"No messages to write for user {actual_end_user_id}") - return - - logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") - write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: 用户ID - structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] - actual_config_id, # config_id: 配置ID - storage_type, # storage_type: "neo4j" - user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) - ) - logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') - finally: - db.close() async def chat( self, message: str, @@ -520,14 +412,7 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: - long_term_messages=await agent_chat_messages(message_chat,content) - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") + await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id) response = { "content": content, "model": self.model_name, @@ -710,15 +595,7 @@ class LangChainAgent: yield total_tokens break if memory_flag: - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - long_term_messages = await agent_chat_messages(message_chat, full_content) - await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") - + await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index e9de02b6..895f61ac 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -1,8 +1,9 @@ +import json import os from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse +from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ @@ -10,46 +11,108 @@ from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import count_store from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context +from app.db import get_db_context, get_db +from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.schemas.memory_agent_schema import AgentMemory_Long_Term +from app.services.memory_konwledges_server import write_rag +from app.services.task_service import get_task_memory_write_result +from app.tasks import write_message_task +from app.utils.config_utils import resolve_config_id logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') +async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): + # RAG 模式:组合消息为字符串格式(保持原有逻辑) + combined_message = f"user: {user_message}\nassistant: {ai_message}" + await write_rag(end_user_id, combined_message, user_rag_memory_id) + logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') +async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, + actual_config_id, long_term_messages=[]): + """ + 写入记忆(支持结构化消息) -async def write_messages(end_user_id,langchain_messages,memory_config): - ''' - 写入数据到neo4j: - Args: + Args: + storage_type: 存储类型 (neo4j/rag) end_user_id: 终端用户ID - memory_config: 内存配置对象 - langchain_messages:原始数据LIST - ''' + user_message: 用户消息内容 + ai_message: AI 回复内容 + user_rag_memory_id: RAG 记忆ID + actual_end_user_id: 实际用户ID + actual_config_id: 配置ID + + 逻辑说明: + - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 + - Neo4j 模式:使用结构化消息列表 + 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] + 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) + 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + """ + + db = next(get_db()) try: + actual_config_id = resolve_config_id(actual_config_id, db) + # Neo4j 模式:使用结构化消息列表 + structured_messages = [] + + # 始终添加用户消息(如果不为空) + if isinstance(user_message, str) and user_message.strip() != "": + structured_messages.append({"role": "user", "content": user_message}) + + # 只有当 AI 回复不为空时才添加 assistant 消息 + if isinstance(ai_message, str) and ai_message.strip() != "": + structured_messages.append({"role": "assistant", "content": ai_message}) + + # 如果提供了 long_term_messages,使用它替代 structured_messages + if long_term_messages and isinstance(long_term_messages, list): + structured_messages = long_term_messages + elif long_term_messages and isinstance(long_term_messages, str): + # 如果是 JSON 字符串,先解析 + try: + structured_messages = json.loads(long_term_messages) + except json.JSONDecodeError: + logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}") + + # 如果没有消息,直接返回 + if not structured_messages: + logger.warning(f"No messages to write for user {actual_end_user_id}") + return + + logger.info( + f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") + write_id = write_message_task.delay( + actual_end_user_id, # end_user_id: 用户ID + structured_messages, # message: JSON 字符串格式的消息列表 + str(actual_config_id), # config_id: 配置ID字符串 + storage_type, # storage_type: "neo4j" + user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + ) + logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") + write_status = get_task_memory_write_result(str(write_id)) + logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') + finally: + db.close() + +async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): + with get_db_context() as db_session: + repo = LongTermMemoryRepository(db_session) + + + from app.core.memory.agent.utils.redis_tool import write_store + result = write_store.get_session_by_userid(end_user_id) + if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'---------写入短长期-----------') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') + - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config - } - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - # TODO:删除 - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - print(contents) - except Exception as e: - import traceback - traceback.print_exc() '''根据窗口''' async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): ''' @@ -61,25 +124,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): scope:窗口大小 ''' scope=scope - redis_messages = [] is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: is_end_user_id = count_store.get_sessions_count(end_user_id)[0] redis_messages = count_store.get_sessions_count(end_user_id)[1] if is_end_user_id and int(is_end_user_id) != int(scope): - print(is_end_user_id) is_end_user_id += 1 langchain_messages += redis_messages count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) elif int(is_end_user_id) == int(scope): - print('写入长期记忆,并且设置为0') - print(is_end_user_id) - formatted_messages = await chat_data_format(redis_messages) - print(100*'-') - print(formatted_messages) - print(100*'-') - await write_messages(end_user_id, formatted_messages, memory_config) - count_store.update_sessions_count(end_user_id, 0, '') + logger.info('写入长期记忆NEO4J') + formatted_messages = (redis_messages) + # 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用) + if hasattr(memory_config, 'config_id'): + config_id = memory_config.config_id + else: + config_id = memory_config + + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + config_id, formatted_messages) + count_store.update_sessions_count(end_user_id, 1, langchain_messages) else: count_store.save_sessions_count(end_user_id, 1, langchain_messages) @@ -93,12 +157,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time): memory_config: 内存配置对象 ''' long_time_data = write_store.find_user_recent_sessions(end_user_id, time) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - return - format_messages = await chat_data_format(long_time_data) + format_messages = (long_time_data) + messages=[] + memory_config=memory_config.config_id + for i in format_messages: + message=json.loads(i['Query']) + messages+= message if format_messages!=[]: - await write_messages(end_user_id, format_messages, memory_config) + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + memory_config, messages) '''聚合判断''' async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: """ @@ -109,13 +176,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] memory_config: 内存配置对象 """ - + try: # 1. 获取历史会话数据(使用新方法) result = write_store.get_all_sessions_by_end_user_id(end_user_id) - - # Handle case where no session exists in Redis (returns False or empty) - if not result or result is False: + history = await format_parsing(result) + if not result: history = [] else: history = await format_parsing(result) @@ -154,7 +220,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config } if not structured.is_same_event: logger.info(result_dict) - await write_messages(end_user_id, output_value, memory_config) + await write("neo4j", end_user_id, "", "", None, end_user_id, + memory_config.config_id, output_value) return result_dict except Exception as e: diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py index c4814de1..fcbb18e3 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): 清理后的数据 """ # 需要过滤的字段列表 + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 fields_to_remove = { 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', - 'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary" + 'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary" } if isinstance(data, dict): diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py index a1fb8226..9ce581ee 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -1,8 +1,6 @@ import json from langchain_core.messages import HumanMessage, AIMessage - - async def format_parsing(messages: list,type:str='string'): """ 格式化解析消息列表 @@ -26,13 +24,13 @@ async def format_parsing(messages: list,type:str='string'): role = content['role'] content = content['content'] if type == "string": - if role == 'human': + if role == 'human' or role=="user": content = '用户:' + content else: content = 'AI:' + content result.append(content) - if type == "dict": - if role == 'human': + if type == "dict" : + if role == 'human' or role=="user": user.append( content) else: ai.append(content) @@ -57,33 +55,7 @@ async def messages_parse(messages: list | dict): for key, values in zip(user, ai): database.append({key, values}) return database -async def chat_data_format(messages: list | dict): - """ - 将消息格式化为 LangChain 消息格式 - - Args: - messages: 消息列表或字典 - - Returns: - LangChain 消息列表 - """ - langchain_messages = [] - if isinstance(messages, list): - for msg in messages: - if 'role' in msg.keys(): - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - langchain_messages.append(AIMessage(content=msg['content'])) - if "Query" in msg.keys(): - langchain_messages.append(HumanMessage(content=msg['Query'])) - langchain_messages.append(AIMessage(content=msg['Answer'])) - if isinstance(messages, dict): - if messages['type'] == 'human': - langchain_messages.append(HumanMessage(content=messages['content'])) - elif messages['type'] == 'ai': - langchain_messages.append(AIMessage(content=messages['content'])) - return langchain_messages + async def agent_chat_messages(user_content,ai_content): messages = [ diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 9b858f47..1134acc7 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,13 +1,18 @@ import asyncio +import json import sys import warnings from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph +from app.db import get_db, get_db_context from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.schemas.memory_agent_schema import AgentMemory_Long_Term +from app.services.memory_config_service import MemoryConfigService + warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) @@ -37,76 +42,61 @@ async def make_write_graph(): yield graph - -async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '', - end_user_id: str = '', scope: int = 6): - """Dispatch long-term memory storage to Celery background tasks. - - Args: - long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate' - langchain_messages: List of messages to store - memory_config: Memory configuration ID (string) - end_user_id: End user identifier - scope: Window size for 'chunk' strategy (default: 6) - """ - from app.tasks import ( - long_term_storage_window_task, - # TODO: Uncomment when implemented - # long_term_storage_time_task, - # long_term_storage_aggregate_task, - ) - from app.core.logging_config import get_logger - - logger = get_logger(__name__) - - # Convert config to string if needed - config_id = str(memory_config) if memory_config else '' - - if long_term_type == 'chunk': - # Strategy 1: Window-based batching (6 rounds of dialogue) - logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}") - long_term_storage_window_task.delay( - end_user_id=end_user_id, - langchain_messages=langchain_messages, - config_id=config_id, - scope=scope +async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): + from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment + from app.core.memory.agent.utils.redis_tool import write_store + write_store.save_session_write(end_user_id, (langchain_messages)) + # 获取数据库会话 + with get_db_context() as db_session: + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=memory_config, # 改为整数 + service_name="MemoryAgentService" ) - # TODO: Uncomment when time-based strategy is fully implemented - # elif long_term_type == 'time': - # # Strategy 2: Time-based retrieval - # logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}") - # long_term_storage_time_task.delay( - # end_user_id=end_user_id, - # config_id=config_id, - # time_window=5 - # ) - # TODO: Uncomment when aggregate strategy is fully implemented - # elif long_term_type == 'aggregate': - # # Strategy 3: Aggregate judgment (deduplication) - # logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}") - # long_term_storage_aggregate_task.delay( - # end_user_id=end_user_id, - # langchain_messages=langchain_messages, - # config_id=config_id - # ) + if long_term_type=='chunk': + '''方案一:对话窗口6轮对话''' + await window_dialogue(end_user_id,langchain_messages,memory_config,scope) + if long_term_type=='time': + """时间""" + await memory_long_term_storage(end_user_id, memory_config,5) + if long_term_type=='aggregate': + """方案三:聚合判断""" + await aggregate_judgment(end_user_id, langchain_messages, memory_config) + + + +async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): + from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent + from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save + from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages + if storage_type == AgentMemory_Long_Term.STORAGE_RAG: + await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) + else: + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK + SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE + long_term_messages = await agent_chat_messages(message_chat, aimessages) + await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) + await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) # async def main(): # """主函数 - 运行工作流""" # langchain_messages = [ # { # "role": "user", -# "content": "今天周五好开心啊" +# "content": "今天周五去爬山" # }, # { # "role": "assistant", -# "content": "你也这么觉得,我也是耶" +# "content": "好耶" # } # # ] # end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID # memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" -# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) -# result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) +# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) +# # # # if __name__ == "__main__": diff --git a/api/app/core/memory/agent/utils/redis_tool.py b/api/app/core/memory/agent/utils/redis_tool.py index b61319e5..c5729628 100644 --- a/api/app/core/memory/agent/utils/redis_tool.py +++ b/api/app/core/memory/agent/utils/redis_tool.py @@ -294,6 +294,7 @@ class RedisCountStore: """ session_id = str(uuid.uuid4()) key = generate_session_key(session_id, key_type="count") + index_key = f'session:count:index:{end_user_id}' # 索引键 pipe = self.r.pipeline() pipe.hset(key, mapping={ @@ -304,6 +305,10 @@ class RedisCountStore: "starttime": get_current_timestamp() }) pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 + + # 创建索引:end_user_id -> session_id 映射 + pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60) + result = pipe.execute() print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") @@ -320,31 +325,47 @@ class RedisCountStore: list 或 False: 如果找到返回 [count, messages],否则返回 False """ try: - search_pattern = 'session:count:*' + # 使用索引键快速查找 + index_key = f'session:count:index:{end_user_id}' - for key in self.r.keys(search_pattern): - data = self.r.hgetall(key) - - if not data: - continue - - if data.get('end_user_id') == end_user_id: - count = data.get('count') - messages_str = data.get('messages') - - if count is not None: - messages = deserialize_messages(messages_str) - return [int(count), messages] + # 检查索引键类型,避免 WRONGTYPE 错误 + try: + key_type = self.r.type(index_key) + if key_type != 'string' and key_type != 'none': + self.r.delete(index_key) + return False + except Exception as type_error: + print(f"[get_sessions_count] 检查键类型失败: {type_error}") + + session_id = self.r.get(index_key) + + if not session_id: + return False + + # 直接获取数据 + key = generate_session_key(session_id, key_type="count") + data = self.r.hgetall(key) + + if not data: + # 索引存在但数据不存在,清理索引 + self.r.delete(index_key) + return False + + count = data.get('count') + messages_str = data.get('messages') + + if count is not None: + messages = deserialize_messages(messages_str) + return [int(count), messages] return False except Exception as e: print(f"[get_sessions_count] 查询失败: {e}") return False - def update_sessions_count(self, end_user_id: str, new_count: int, messages: Any) -> bool: """ - 通过 end_user_id 修改访问次数统计 + 通过 end_user_id 修改访问次数统计(优化版:使用索引) Args: end_user_id: 终端用户ID @@ -355,23 +376,39 @@ class RedisCountStore: bool: 更新成功返回 True,未找到记录返回 False """ try: + # 使用索引键快速查找 + index_key = f'session:count:index:{end_user_id}' + + # 检查索引键类型,避免 WRONGTYPE 错误 + try: + key_type = self.r.type(index_key) + if key_type != 'string' and key_type != 'none': + # 索引键类型错误,删除并返回 False + print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") + self.r.delete(index_key) + print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + return False + except Exception as type_error: + print(f"[update_sessions_count] 检查键类型失败: {type_error}") + + session_id = self.r.get(index_key) + + if not session_id: + print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + return False + + # 直接更新数据 + key = generate_session_key(session_id, key_type="count") messages_str = serialize_messages(messages) - search_pattern = 'session:count:*' - for key in self.r.keys(search_pattern): - data = self.r.hgetall(key) - - if not data: - continue - - if data.get('end_user_id') == end_user_id: - self.r.hset(key, 'count', int(new_count)) - self.r.hset(key, 'messages', messages_str) - print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") - return True + pipe = self.r.pipeline() + pipe.hset(key, 'count', int(new_count)) + pipe.hset(key, 'messages', messages_str) + result = pipe.execute() + + print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") + return True - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") - return False except Exception as e: print(f"[update_sessions_count] 更新失败: {e}") return False diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 76a28156..fadc7669 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline This module provides the main write function for executing the knowledge extraction pipeline. Only MemoryConfig is needed - clients are constructed internally. """ +import asyncio import time from datetime import datetime @@ -124,23 +125,48 @@ async def write( except Exception as e: logger.error(f"Error creating indexes: {e}", exc_info=True) + # 添加死锁重试机制 + max_retries = 3 + retry_delay = 1 # 秒 + + for attempt in range(max_retries): + try: + success = await save_dialog_and_statements_to_neo4j( + dialogue_nodes=all_dialogue_nodes, + chunk_nodes=all_chunk_nodes, + statement_nodes=all_statement_nodes, + entity_nodes=all_entity_nodes, + statement_chunk_edges=all_statement_chunk_edges, + statement_entity_edges=all_statement_entity_edges, + entity_edges=all_entity_entity_edges, + connector=neo4j_connector + ) + if success: + logger.info("Successfully saved all data to Neo4j") + break + else: + logger.warning("Failed to save some data to Neo4j") + if attempt < max_retries - 1: + logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})") + await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 + except Exception as e: + error_msg = str(e) + # 检查是否是死锁错误 + if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower(): + if attempt < max_retries - 1: + logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})") + await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 + else: + logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}") + raise + else: + # 非死锁错误,直接抛出 + raise + try: - success = await save_dialog_and_statements_to_neo4j( - dialogue_nodes=all_dialogue_nodes, - chunk_nodes=all_chunk_nodes, - statement_nodes=all_statement_nodes, - entity_nodes=all_entity_nodes, - statement_chunk_edges=all_statement_chunk_edges, - statement_entity_edges=all_statement_entity_edges, - entity_edges=all_entity_entity_edges, - connector=neo4j_connector - ) - if success: - logger.info("Successfully saved all data to Neo4j") - else: - logger.warning("Failed to save some data to Neo4j") - finally: await neo4j_connector.close() + except Exception as e: + logger.error(f"Error closing Neo4j connector: {e}") log_time("Neo4j Database Save", time.time() - step_start, log_file) diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 79b88fdc..1880b9ab 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -413,7 +413,8 @@ class ExtractedEntityNode(Node): description="Entity aliases - alternative names for this entity" ) name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") - fact_summary: str = Field(default="", description="Summary of the fact about this entity") + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary: str = Field(default="", description="Summary of the fact about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index a425e0ed..f2f14d9e 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode): if len(desc_b) > len(desc_a): canonical.description = desc_b # 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序 - fact_a = getattr(canonical, "fact_summary", "") or "" - fact_b = getattr(ent, "fact_summary", "") or "" - def _extract_sources(txt: str) -> List[str]: - sources: List[str] = [] - if not txt: - return sources - for line in str(txt).splitlines(): - ln = line.strip() + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_a = getattr(canonical, "fact_summary", "") or "" + # fact_b = getattr(ent, "fact_summary", "") or "" + # def _extract_sources(txt: str) -> List[str]: + # sources: List[str] = [] + # if not txt: + # return sources + # for line in str(txt).splitlines(): + # ln = line.strip() # 支持“来源:”或“来源:”前缀 - m = re.match(r"^来源[::]\s*(.+)$", ln) - if m: - content = m.group(1).strip() - if content: - sources.append(content) + # m = re.match(r"^来源[::]\s*(.+)$", ln) + # if m: + # content = m.group(1).strip() + # if content: + # sources.append(content) # 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失 - if not sources and txt.strip(): - sources.append(txt.strip()) - return sources + # if not sources and txt.strip(): + # sources.append(txt.strip()) + # return sources try: - src_a = _extract_sources(fact_a) - src_b = _extract_sources(fact_b) - seen = set() - merged_sources: List[str] = [] - for s in src_a + src_b: - if s and s not in seen: - seen.add(s) - merged_sources.append(s) - if merged_sources: - name_line = f"实体: {getattr(canonical, 'name', '')}".strip() - canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources]) - elif fact_b and not fact_a: - canonical.fact_summary = fact_b + # src_a = _extract_sources(fact_a) + # src_b = _extract_sources(fact_b) + # seen = set() + # merged_sources: List[str] = [] + # for s in src_a + src_b: + # if s and s not in seen: + # seen.add(s) + # merged_sources.append(s) + # if merged_sources: + # name_line = f"实体: {getattr(canonical, 'name', '')}".strip() + # canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources]) + # elif fact_b and not fact_a: + # canonical.fact_summary = fact_b + pass except Exception: # 兜底:若解析失败,保留较长文本 - if len(fact_b) > len(fact_a): - canonical.fact_summary = fact_b + # if len(fact_b) > len(fact_a): + # canonical.fact_summary = fact_b + pass except Exception: pass diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py index 0249ac1f..a028e916 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py @@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: # # 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整) desc_a = (getattr(a, "description", "") or "") desc_b = (getattr(b, "description", "") or "") - fact_a = (getattr(a, "fact_summary", "") or "") - fact_b = (getattr(b, "fact_summary", "") or "") - score_a = len(desc_a) + len(fact_a) - score_b = len(desc_b) + len(fact_b) + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_a = (getattr(a, "fact_summary", "") or "") + # fact_b = (getattr(b, "fact_summary", "") or "") + # score_a = len(desc_a) + len(fact_a) + # score_b = len(desc_b) + len(fact_b) + score_a = len(desc_a) + score_b = len(desc_b) if score_a != score_b: return 0 if score_a >= score_b else 1 return 0 @@ -189,7 +192,8 @@ async def _judge_pair( "entity_type": getattr(a, "entity_type", None), "description": getattr(a, "description", None), "aliases": getattr(a, "aliases", None) or [], - "fact_summary": getattr(a, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(a, "fact_summary", None), "connect_strength": getattr(a, "connect_strength", None), } entity_b = { @@ -197,7 +201,8 @@ async def _judge_pair( "entity_type": getattr(b, "entity_type", None), "description": getattr(b, "description", None), "aliases": getattr(b, "aliases", None) or [], - "fact_summary": getattr(b, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(b, "fact_summary", None), "connect_strength": getattr(b, "connect_strength", None), } # 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式) @@ -248,7 +253,8 @@ async def _judge_pair_disamb( "entity_type": getattr(a, "entity_type", None), "description": getattr(a, "description", None), "aliases": getattr(a, "aliases", None) or [], - "fact_summary": getattr(a, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(a, "fact_summary", None), "connect_strength": getattr(a, "connect_strength", None), } entity_b = { @@ -256,7 +262,8 @@ async def _judge_pair_disamb( "entity_type": getattr(b, "entity_type", None), "description": getattr(b, "description", None), "aliases": getattr(b, "aliases", None) or [], - "fact_summary": getattr(b, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(b, "fact_summary", None), "connect_strength": getattr(b, "connect_strength", None), } prompt = render_entity_dedup_prompt( diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py index dbc697d9..028a926f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py @@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode: description=row.get("description") or "", aliases=row.get("aliases") or [], name_embedding=row.get("name_embedding") or [], - fact_summary=row.get("fact_summary") or "", + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary=row.get("fact_summary") or "", connect_strength=row.get("connect_strength") or "", ) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 98bec522..08be0aeb 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1088,7 +1088,8 @@ class ExtractionOrchestrator: entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type description=getattr(entity, 'description', ''), # 添加必需的 description 字段 example=getattr(entity, 'example', ''), # 新增:传递示例字段 - fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), diff --git a/api/app/core/memory/utils/alias_utils.py b/api/app/core/memory/utils/alias_utils.py index df75752a..ff139128 100644 --- a/api/app/core/memory/utils/alias_utils.py +++ b/api/app/core/memory/utils/alias_utils.py @@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li key=lambda eid: ( _strength_rank(eid), len(getattr(entity_by_id.get(eid), 'description', '') or ''), - len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '') + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '') + 0 # 临时占位 ), reverse=True ) diff --git a/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 b/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 index be53c9d4..7fb465a2 100644 --- a/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 @@ -9,7 +9,8 @@ - 类型: "{{ entity_a.entity_type | default('') }}" - 描述: "{{ entity_a.description | default('') }}" - 别名: {{ entity_a.aliases | default([]) }} -- 摘要: "{{ entity_a.fact_summary | default('') }}" +{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #} +{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #} - 连接强弱: "{{ entity_a.connect_strength | default('') }}" 实体B: @@ -17,7 +18,8 @@ - 类型: "{{ entity_b.entity_type | default('') }}" - 描述: "{{ entity_b.description | default('') }}" - 别名: {{ entity_b.aliases | default([]) }} -- 摘要: "{{ entity_b.fact_summary | default('') }}" +{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #} +{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #} - 连接强弱: "{{ entity_b.connect_strength | default('') }}" 上下文: diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index 1f696c98..65fbd9cb 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD from app.core.rag.llm.chat_model import Base from app.core.rag.llm.embedding_model import OpenAIEmbed +import logging +logger = logging.getLogger(__name__) def knowledge_retrieval( query: str, @@ -62,7 +64,15 @@ def knowledge_retrieval( merge_strategy = config.get("merge_strategy", "weight") reranker_id = config.get("reranker_id") reranker_top_k = config.get("reranker_top_k", 1024) - use_graph = config.get("use_graph", "false").lower() == "true" + # use_graph = config.get("use_graph", "false").lower() == "true" + + use_graph_value = config.get("use_graph", False) + if isinstance(use_graph_value, bool): + use_graph = use_graph_value + elif isinstance(use_graph_value, str): + use_graph = use_graph_value.lower() in ("true", "1", "yes") + else: + use_graph = False file_names_filter = [] if user_ids: @@ -159,13 +169,29 @@ def knowledge_retrieval( # Use the specified reranker for re-ranking if reranker_id: - return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) - # use graph + try: + return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) + except Exception as rerank_error: + # If reranker fails, log warning and continue with original results + logger.warning( + "Reranker failed, falling back to original results", + extra={ + "reranker_id": reranker_id, + "query": query, + "doc_count": len(all_results), + "error": str(rerank_error), + }, + ) + if use_graph: - from app.core.rag.common.settings import kg_retriever - doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) - if doc: - all_results.insert(0, doc) + try: + from app.core.rag.common.settings import kg_retriever + doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) + if doc: + all_results.insert(0, doc) + except Exception as graph_error: + print(f"Failed to retrieve from knowledge graph: {str(graph_error)}") + return all_results except Exception as e: diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 475c54fe..31acaafc 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -25,6 +25,18 @@ class ParameterExtractorNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config: ParameterExtractorNodeConfig | None = None + self.response_metadata = {} + + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: + if self.response_metadata: + usage = self.response_metadata.get('token_usage') + if usage: + return { + "prompt_tokens": usage.get('prompt_tokens', 0), + "completion_tokens": usage.get('completion_tokens', 0), + "total_tokens": usage.get('total_tokens', 0) + } + return None def _output_types(self) -> dict[str, VariableType]: outputs = {} @@ -180,6 +192,7 @@ class ParameterExtractorNode(BaseNode): ]) model_resp = await llm.ainvoke(messages) + self.response_metadata = model_resp.response_metadata result = json_repair.repair_json(model_resp.content, return_objects=True) logger.info(f"node: {self.node_id} get params:{result}") diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index d7496f12..38662b64 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -25,6 +25,18 @@ class QuestionClassifierNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: QuestionClassifierNodeConfig | None = None self.category_to_case_map = {} + self.response_metadata = {} + + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: + if self.response_metadata: + usage = self.response_metadata.get('token_usage') + if usage: + return { + "prompt_tokens": usage.get('prompt_tokens', 0), + "completion_tokens": usage.get('completion_tokens', 0), + "total_tokens": usage.get('total_tokens', 0) + } + return None def _output_types(self) -> dict[str, VariableType]: return { @@ -120,6 +132,7 @@ class QuestionClassifierNode(BaseNode): response = await llm.ainvoke(messages) result = response.content.strip() + self.response_metadata = response.response_metadata if result in category_names: category = result diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 22972669..68e7cb04 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -86,7 +86,8 @@ class MemoryConfigRepository: n.description AS description, n.entity_type AS entity_type, n.name AS name, - COALESCE(n.fact_summary, '') AS fact_summary, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // COALESCE(n.fact_summary, '') AS fact_summary, n.end_user_id AS end_user_id, n.apply_id AS apply_id, n.user_id AS user_id, @@ -279,6 +280,9 @@ class MemoryConfigRepository: if update.config_desc is not None: db_config.config_desc = update.config_desc has_update = True + if update.scene_id is not None: + db_config.scene_id = update.scene_id + has_update = True if not has_update: raise ValueError("No fields to update") @@ -650,28 +654,32 @@ class MemoryConfigRepository: raise @staticmethod - def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]: - """获取所有配置参数 + def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]: + """获取所有配置参数,包含关联的场景名称 Args: db: 数据库会话 workspace_id: 工作空间ID,用于过滤查询结果 Returns: - List[MemoryConfig]: 配置列表 + List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称) """ + from app.models.ontology_scene import OntologyScene + db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") try: - query = db.query(MemoryConfig) + query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin( + OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id + ) if workspace_id: query = query.filter(MemoryConfig.workspace_id == workspace_id) - configs = query.order_by(desc(MemoryConfig.updated_at)).all() + results = query.order_by(desc(MemoryConfig.updated_at)).all() - db_logger.debug(f"配置列表查询成功: 数量={len(configs)}") - return configs + db_logger.debug(f"配置列表查询成功: 数量={len(results)}") + return results except Exception as e: db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}") diff --git a/api/app/repositories/neo4j/add_edges.py b/api/app/repositories/neo4j/add_edges.py index 162bf411..2b32551c 100644 --- a/api/app/repositories/neo4j/add_edges.py +++ b/api/app/repositories/neo4j/add_edges.py @@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], try: edges: List[dict] = [] for s in summaries: - for chunk_id in getattr(s, "chunk_ids", []) or []: + chunk_ids = getattr(s, "chunk_ids", []) or [] + for chunk_id in chunk_ids: edges.append({ "summary_id": s.id, "chunk_id": chunk_id, @@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], if not edges: return [] - result = await connector.execute_query( MEMORY_SUMMARY_STATEMENT_EDGE_SAVE, edges=edges ) created = [record.get("uuid") for record in result] if result else [] return created - except Exception: + except Exception as e: return None diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index fcf700b5..42c178b3 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector summaries=flattened ) created_ids = [record.get("uuid") for record in result] + print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids - except Exception: + except Exception as e: + print(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index cf1732fd..651c513f 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity e.name_embedding = CASE WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding ELSE e.name_embedding END, - e.fact_summary = CASE - WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> '' - AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary)) - THEN entity.fact_summary ELSE e.fact_summary END, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // e.fact_summary = CASE + // WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> '' + // AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary)) + // THEN entity.fact_summary ELSE e.fact_summary END, e.connect_strength = CASE WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength ELSE CASE @@ -321,7 +322,8 @@ RETURN e.id AS id, e.description AS description, e.aliases AS aliases, e.name_embedding AS name_embedding, - COALESCE(e.fact_summary, '') AS fact_summary, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // COALESCE(e.fact_summary, '') AS fact_summary, e.connect_strength AS connect_strength, collect(DISTINCT s.id) AS statement_ids, collect(DISTINCT c.id) AS chunk_ids, @@ -1002,3 +1004,58 @@ RETURN DISTINCT x.statement as statement,x.created_at as created_at """ +Graph_Node_query = """ + MATCH (n:MemorySummary) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 0 AS priority + LIMIT $limit + + UNION ALL + + MATCH (n:Dialogue) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority + LIMIT 1 + + UNION ALL + + MATCH (n:Statement) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority + LIMIT $limit + + UNION ALL + + MATCH (n:ExtractedEntity) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 2 AS priority + LIMIT $limit + + UNION ALL + + MATCH (n:Chunk) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 3 AS priority + LIMIT $limit + + """ \ No newline at end of file diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index fc32ca9a..526d16ec 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import ( ExtractedEntityNode, EntityEntityEdge, ) - +import logging +logger = logging.getLogger(__name__) async def save_entities_and_relationships( entity_nodes: List[ExtractedEntityNode], entity_entity_edges: List[EntityEntityEdge], @@ -41,8 +42,8 @@ async def save_entities_and_relationships( 'statement': edge.statement, 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, - 'created_at': edge.created_at.isoformat(), - 'expired_at': edge.expired_at.isoformat(), + 'created_at': edge.created_at.isoformat() if edge.created_at else None, + 'expired_at': edge.expired_at.isoformat() if edge.expired_at else None, 'run_id': edge.run_id, 'end_user_id': edge.end_user_id, } @@ -147,14 +148,14 @@ async def save_statement_entity_edges( async def save_dialog_and_statements_to_neo4j( - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - entity_edges: List[EntityEntityEdge], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + entity_edges: List[EntityEntityEdge], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + connector: Neo4jConnector ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. @@ -171,40 +172,126 @@ async def save_dialog_and_statements_to_neo4j( Returns: bool: True if successful, False otherwise """ - try: - # Save all dialogue nodes in batch - dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector) - if dialogue_uuids: + + # 定义事务函数,将所有写操作放在一个事务中 + async def _save_all_in_transaction(tx): + """在单个事务中执行所有保存操作,避免死锁""" + results = {} + + # 1. Save all dialogue nodes in batch + if dialogue_nodes: + from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE + dialogue_data = [node.model_dump() for node in dialogue_nodes] + result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data) + dialogue_uuids = [record["uuid"] async for record in result] + results['dialogues'] = dialogue_uuids print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") - else: - print("Failed to save dialogues to Neo4j") - return False - # Save all chunk nodes in batch - await save_chunk_nodes(chunk_nodes, connector) + # 2. Save all chunk nodes in batch + if chunk_nodes: + from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE + chunk_data = [node.model_dump() for node in chunk_nodes] + result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data) + chunk_uuids = [record["uuid"] async for record in result] + results['chunks'] = chunk_uuids + logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") - # Save all statement nodes in batch + # 3. Save all statement nodes in batch if statement_nodes: - statement_uuids = await add_statement_nodes(statement_nodes, connector) - if statement_uuids: - print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") - else: - print("Failed to save statement nodes to Neo4j") - return False - else: - print("No statement nodes to save") + from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE + statement_data = [node.model_dump() for node in statement_nodes] + result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data) + statement_uuids = [record["uuid"] async for record in result] + results['statements'] = statement_uuids + logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") - # Save entities and relationships - await save_entities_and_relationships(entity_nodes, entity_edges, connector) - print("Successfully saved entities and relationships to Neo4j") + # 4. Save entities + if entity_nodes: + from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE + entity_data = [entity.model_dump() for entity in entity_nodes] + result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data) + entity_uuids = [record["uuid"] async for record in result] + results['entities'] = entity_uuids + logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j") - # Save new edges - await save_statement_chunk_edges(statement_chunk_edges, connector) - await save_statement_entity_edges(statement_entity_edges, connector) + # 5. Create entity relationships + if entity_edges: + from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE + relationship_data = [] + for edge in entity_edges: + relationship_data.append({ + 'source_id': edge.source, + 'target_id': edge.target, + 'predicate': edge.relation_type, + 'statement_id': edge.source_statement_id, + 'value': edge.relation_value, + 'statement': edge.statement, + 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, + 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, + 'created_at': edge.created_at.isoformat() if edge.created_at else None, + 'expired_at': edge.expired_at.isoformat() if edge.expired_at else None, + 'run_id': edge.run_id, + 'end_user_id': edge.end_user_id, + }) + result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data) + rel_uuids = [record["uuid"] async for record in result] + results['entity_relationships'] = rel_uuids + logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j") + # 6. Save statement-chunk edges + if statement_chunk_edges: + from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE + sc_edge_data = [] + for edge in statement_chunk_edges: + sc_edge_data.append({ + "id": edge.id, + "source": edge.source, + "target": edge.target, + "created_at": edge.created_at.isoformat() if edge.created_at else None, + "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, + "run_id": edge.run_id, + "end_user_id": edge.end_user_id, + }) + result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data) + sc_uuids = [record["uuid"] async for record in result] + results['statement_chunk_edges'] = sc_uuids + logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j") + + # 7. Save statement-entity edges + if statement_entity_edges: + from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE + se_edge_data = [] + for edge in statement_entity_edges: + se_edge_data.append({ + "source": edge.source, + "target": edge.target, + "created_at": edge.created_at.isoformat() if edge.created_at else None, + "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, + "run_id": edge.run_id, + "end_user_id": edge.end_user_id, + "connect_strength": getattr(edge, "connect_strength", "strong"), + }) + result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data) + se_uuids = [record["uuid"] async for record in result] + results['statement_entity_edges'] = se_uuids + logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") + + return results + + try: + # 使用显式写事务执行所有操作,避免死锁 + results = await connector.execute_write_transaction(_save_all_in_transaction) + summary = { + key: len(value) + for key, value in results.items() + if isinstance(value, (list, tuple, set)) + } + logger.info("Transaction completed. Summary: %s", summary) + logger.debug("Full transaction results: %r", results) return True except Exception as e: + logger.error(f"Neo4j integration error: {e}", exc_info=True) print(f"Neo4j integration error: {e}") print("Continuing without database storage...") - return False \ No newline at end of file + return False diff --git a/api/app/repositories/ontology_scene_repository.py b/api/app/repositories/ontology_scene_repository.py index 322e111c..141b5d1c 100644 --- a/api/app/repositories/ontology_scene_repository.py +++ b/api/app/repositories/ontology_scene_repository.py @@ -392,3 +392,48 @@ class OntologySceneRepository: exc_info=True ) raise + + def get_simple_list(self, workspace_id: UUID) -> List[dict]: + """获取场景简单列表(仅包含scene_id和scene_name,用于下拉选择) + + 这是一个轻量级查询,不加载关联的classes,响应速度快。 + + Args: + workspace_id: 工作空间ID + + Returns: + List[dict]: 场景简单列表,每项包含scene_id和scene_name + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scenes = repo.get_simple_list(workspace_id) + >>> # [{"scene_id": "xxx", "scene_name": "场景1"}, ...] + """ + try: + logger.debug(f"Getting simple scene list for workspace: {workspace_id}") + + # 只查询需要的字段,不加载关联数据 + results = self.db.query( + OntologyScene.scene_id, + OntologyScene.scene_name + ).filter( + OntologyScene.workspace_id == workspace_id + ).order_by( + OntologyScene.updated_at.desc() + ).all() + + scenes = [ + {"scene_id": str(r.scene_id), "scene_name": r.scene_name} + for r in results + ] + + logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}") + + return scenes + + except Exception as e: + logger.error( + f"Failed to get simple scene list: {str(e)}", + exc_info=True + ) + raise diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index b6f50dd7..1a5017eb 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import Optional from pydantic import BaseModel @@ -14,4 +15,15 @@ class UserInput(BaseModel): class Write_UserInput(BaseModel): messages: list[dict] end_user_id: str - config_id: Optional[str] = None \ No newline at end of file + config_id: Optional[str] = None + +class AgentMemory_Long_Term(ABC): + """长期记忆配置常量""" + STORAGE_NEO4J = "neo4j" + STORAGE_RAG = "rag" + STRATEGY_AGGREGATE = "aggregate" + STRATEGY_CHUNK = "chunk" + STRATEGY_TIME = "time" + DEFAULT_SCOPE = 6 + + diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 11cacda0..c3e7295b 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -248,8 +248,9 @@ class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 config_id: Union[uuid.UUID, int, str] = None - config_name: str = Field("配置名称", description="配置名称(字符串)") - config_desc: str = Field("配置描述", description="配置描述(字符串)") + config_name: Optional[str] = Field(None, description="配置名称(字符串)") + config_desc: Optional[str] = Field(None, description="配置描述(字符串)") + scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID") class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 31662769..3b301743 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -114,6 +114,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str result = task_service.get_task_memory_read_result(task.id) status = result.get("status") logger.info(f"读取任务状态:{status}") + if memory_content: + memory_content = memory_content['answer'] finally: db.close() @@ -127,11 +129,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str "content_length": len(str(memory_content)) } ) - - # 检查是否有有效内容 - if not memory_content or str(memory_content).strip() == "" or "answer" in str(memory_content) and str(memory_content).count("''") > 0: - return "未找到相关的历史记忆。请直接回答用户的问题,不要再次调用此工具。" - return f"检索到以下历史记忆:\n\n{memory_content}" except Exception as e: logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__}) diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index b7079e62..c9327ccf 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -183,11 +183,11 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # --- Read All --- def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 - configs = MemoryConfigRepository.get_all(self.db, workspace_id) + results = MemoryConfigRepository.get_all(self.db, workspace_id) # 将 ORM 对象转换为字典列表 data_list = [] - for config in configs: + for config, scene_name in results: # 安全地转换 user_id 为 int config_id_old = None if config.config_id_old: @@ -209,7 +209,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "end_user_id": config.end_user_id, "config_id_old": config_id_old, "apply_id": config.apply_id, - "scene_id": config.scene_id, + "scene_id": str(config.scene_id) if config.scene_id else None, + "scene_name": scene_name, # 新增:场景名称 "llm_id": config.llm_id, "embedding_id": config.embedding_id, "rerank_id": config.rerank_id, @@ -637,10 +638,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]: if m < 1: latest_relative = "刚刚" elif m < 60: - latest_relative = f"{m}分钟前" + latest_relative = "一会前" else: - h = int(m // 60) - latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前" + latest_relative = "较早前" except Exception: pass diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index f454359c..716f74ec 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.conversation_repository import ConversationRepository from app.repositories.end_user_repository import EndUserRepository +from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping from app.services.implicit_memory_service import ImplicitMemoryService @@ -1521,7 +1522,6 @@ async def analytics_graph_data( user_uuid = uuid.UUID(end_user_id) repo = EndUserRepository(db) end_user = repo.get_by_id(user_uuid) - if not end_user: logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") return { @@ -1575,21 +1575,11 @@ async def analytics_graph_data( } else: # 查询所有节点 - node_query = """ - MATCH (n) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) as id, - labels(n)[0] as label, - properties(n) as properties - LIMIT $limit - """ + node_query=Graph_Node_query node_params = { "end_user_id": end_user_id, "limit": limit } - - # 执行节点查询 node_results = await _neo4j_connector.execute_query(node_query, **node_params) @@ -1600,9 +1590,9 @@ async def analytics_graph_data( for record in node_results: node_id = record["id"] - node_label = record["label"] + node_labels = record.get("labels", []) + node_label = node_labels[0] if node_labels else "Unknown" node_props = record["properties"] - # 根据节点类型提取需要的属性字段 filtered_props = await _extract_node_properties(node_label, node_props,node_id) diff --git a/api/app/utils/config_utils.py b/api/app/utils/config_utils.py index cc67afd2..55cfe8a3 100644 --- a/api/app/utils/config_utils.py +++ b/api/app/utils/config_utils.py @@ -5,42 +5,68 @@ Shared utilities for configuration handling to avoid circular imports. """ from uuid import UUID from sqlalchemy.orm import Session +import uuid as uuid_module -def resolve_config_id(config_id: UUID | int|str, db: Session) -> UUID: +def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID: """ - 解析 config_id,如果是整数则通过 config_id_old 查找对应的 UUID + 解析 config_id,支持 UUID、UUID字符串、整数等多种格式 Args: - config_id: 配置ID(UUID 或整数) + config_id: 配置ID(UUID、UUID字符串 或 整数) db: 数据库会话 Returns: UUID: 解析后的配置ID Raises: - ValueError: 当找不到对应的配置时 + ValueError: 当找不到对应的配置时或格式无效时 """ - from app.models.memory_config_model import MemoryConfig - if isinstance(config_id, UUID): + + # 1. 如果已经是 UUID 类型,直接返回 + if isinstance(config_id, UUID): return config_id - if isinstance(config_id, str) and len(config_id)<=6: - memory_config = db.query(MemoryConfig).filter( - MemoryConfig.config_id_old == int(config_id) - ).first() - print(memory_config) - if not memory_config: - raise ValueError(f"STR 未找到 config_id_old={config_id} 对应的配置") - return memory_config.config_id + + # 2. 如果是字符串类型 + if isinstance(config_id, str): + config_id_stripped = config_id.strip() + + # 2.1 尝试解析为 UUID(标准 UUID 字符串长度为 36) + try: + return uuid_module.UUID(config_id_stripped) + except ValueError: + pass + + # 2.2 尝试解析为整数(用于查询 config_id_old) + try: + old_id = int(config_id_stripped) + if old_id > 0: + memory_config = db.query(MemoryConfig).filter( + MemoryConfig.config_id_old == old_id + ).first() + if not memory_config: + raise ValueError(f"未找到 config_id_old={old_id} 对应的配置") + return memory_config.config_id + except ValueError: + pass + + # 2.3 无法解析的字符串格式 + raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)") + + # 3. 如果是整数类型,通过 config_id_old 查找 if isinstance(config_id, int): + if config_id <= 0: + raise ValueError(f"config_id 必须是正整数: {config_id}") + memory_config = db.query(MemoryConfig).filter( MemoryConfig.config_id_old == config_id ).first() if not memory_config: - raise ValueError(f"INT 未找到 config_id_old={config_id} 对应的配置") + raise ValueError(f"未找到 config_id_old={config_id} 对应的配置") return memory_config.config_id - return config_id + # 4. 不支持的类型 + raise ValueError(f"不支持的 config_id 类型: {type(config_id).__name__}") diff --git a/web/package.json b/web/package.json index e28e8b56..89800fcf 100644 --- a/web/package.json +++ b/web/package.json @@ -13,6 +13,14 @@ "@antv/layout": "^1.2.14-beta.8", "@antv/x6": "^3.0.1", "@antv/x6-react-shape": "^3.0.1", + "@codemirror/lang-cpp": "^6.0.3", + "@codemirror/lang-java": "^6.0.2", + "@codemirror/lang-javascript": "^6.2.4", + "@codemirror/lang-python": "^6.2.1", + "@codemirror/lang-rust": "^6.0.2", + "@codemirror/state": "^6.5.4", + "@codemirror/theme-one-dark": "^6.1.3", + "@codemirror/view": "^6.39.12", "@dnd-kit/core": "^6.3.1", "@dnd-kit/modifiers": "^9.0.0", "@dnd-kit/sortable": "^10.0.0", @@ -25,6 +33,7 @@ "antd": "^5.27.4", "axios": "^1.12.2", "clsx": "^2.1.1", + "codemirror": "^6.0.2", "copy-to-clipboard": "^3.3.3", "crypto-js": "^4.2.0", "dayjs": "^1.11.18", @@ -55,6 +64,7 @@ "@tailwindcss/postcss": "^4.1.14", "@tailwindcss/typography": "^0.5.19", "@tailwindcss/vite": "^4.1.14", + "@types/codemirror": "^5.60.17", "@types/crypto-js": "^4.2.2", "@types/js-yaml": "^4.0.9", "@types/node": "^24.6.0", diff --git a/web/src/api/ontology.ts b/web/src/api/ontology.ts index becf899f..90a6857f 100644 --- a/web/src/api/ontology.ts +++ b/web/src/api/ontology.ts @@ -8,6 +8,7 @@ import { request } from '@/utils/request' import type { Query, OntologyModalData, OntologyClassModalData, OntologyClassExtractModalData, OntologyExportModalData } from '@/views/Ontology/types' // Scene list +export const getOntologyScenesSimpleUrl = '/memory/ontology/scenes/simple' export const getOntologyScenesUrl = '/memory/ontology/scenes' export const getOntologyScenesList = (data: Query) => { return request.get(getOntologyScenesUrl, data) diff --git a/web/src/components/CodeMirrorEditor/index.tsx b/web/src/components/CodeMirrorEditor/index.tsx new file mode 100644 index 00000000..e100b75b --- /dev/null +++ b/web/src/components/CodeMirrorEditor/index.tsx @@ -0,0 +1,150 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-04 17:20:52 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 17:20:52 + */ +import { useEffect, useRef, useMemo } from 'react'; +import { EditorView, basicSetup } from 'codemirror'; +import { EditorState } from '@codemirror/state'; +import { python } from '@codemirror/lang-python'; +import { javascript } from '@codemirror/lang-javascript'; +import { java } from '@codemirror/lang-java'; +import { cpp } from '@codemirror/lang-cpp'; +import { rust } from '@codemirror/lang-rust'; +import { oneDark } from '@codemirror/theme-one-dark'; + +/** + * Props for the CodeMirrorEditor component + * @property {string} value - The initial code content to display in the editor + * @property {string} language - Programming language for syntax highlighting (python, python3, javascript, typescript, java, cpp, c, rust) + * @property {function} onChange - Callback function triggered when editor content changes, receives the new code value + * @property {string} theme - Editor theme, either 'light' or 'dark' + * @property {boolean} readOnly - Whether the editor is read-only + * @property {string} height - Custom height for the editor + * @property {string} size - Predefined size preset: 'default' (120px min-height, 14px font) or 'small' (60px min-height, 12px font) + */ +interface CodeMirrorEditorProps { + value?: string; + language?: 'python' | 'python3' | 'javascript' | 'typescript' | 'java' | 'cpp' | 'c' | 'rust'; + onChange?: (value: string) => void; + theme?: 'light' | 'dark'; + readOnly?: boolean; + height?: string; + size?: 'default' | 'small'; +} + +/** + * Map of language identifiers to their corresponding CodeMirror language extensions + * Supports multiple programming languages with syntax highlighting + */ +const languageExtensions: Record = { + python: python(), + python3: python(), + javascript: javascript(), + typescript: javascript({ typescript: true }), + java: java(), + cpp: cpp(), + c: cpp(), + rust: rust(), +}; + +/** + * CodeMirrorEditor - A React wrapper component for CodeMirror 6 editor + * Provides a code editor with syntax highlighting, theme support, and customizable sizing + * Used in workflow code execution nodes for editing Python and JavaScript code + */ +const CodeMirrorEditor = ({ + value = '', + language = 'javascript', + onChange, + theme = 'light', + readOnly = false, + size, +}: CodeMirrorEditorProps) => { + // Reference to the DOM element that will contain the editor + const editorRef = useRef(null); + // Reference to the CodeMirror EditorView instance + const viewRef = useRef(null); + + /** + * Initialize CodeMirror editor when component mounts or when language/theme/readOnly changes + * Sets up extensions for syntax highlighting, change listeners, and theme + */ + useEffect(() => { + if (!editorRef.current) return; + + // Get the appropriate language extension, fallback to JavaScript if not found + const langExtension = languageExtensions[language] || languageExtensions.javascript; + + // Configure editor extensions + const extensions = [ + basicSetup, // Basic editor features (line numbers, bracket matching, etc.) + langExtension, // Language-specific syntax highlighting + // Listen for document changes and trigger onChange callback + EditorView.updateListener.of((update) => { + if (update.docChanged && onChange) { + onChange(update.state.doc.toString()); + } + }), + EditorState.readOnly.of(readOnly), // Set read-only mode + ]; + + // Apply dark theme if specified + if (theme === 'dark') { + extensions.push(oneDark); + } + + // Create editor state with initial value and extensions + const state = EditorState.create({ + doc: value, + extensions, + }); + + // Create and mount the editor view + viewRef.current = new EditorView({ + state, + parent: editorRef.current, + }); + + // Cleanup: destroy editor instance when component unmounts or dependencies change + return () => { + viewRef.current?.destroy(); + }; + }, [language, theme, readOnly]); + + /** + * Update editor content when the value prop changes externally + * Only updates if the new value differs from current editor content + */ + useEffect(() => { + if (viewRef.current && value !== viewRef.current.state.doc.toString()) { + viewRef.current.dispatch({ + changes: { + from: 0, + to: viewRef.current.state.doc.length, + insert: value, + }, + }); + } + }, [value]); + + // Calculate minimum height based on size prop: small (60px) or default (120px) + const minHeight = useMemo(() => { + return `${size === 'small' ? 60 : 120}px` + }, [size]) + + // Calculate font size based on size prop: small (12px) or default (14px) + const fontSize = useMemo(() => { + return `${size === 'small' ? 12 : 14}px` + }, [size]) + + // Calculate line height based on size prop: small (16px) or default (20px) + const lineHeight = useMemo(() => { + return `${size === 'small' ? 16 : 20}px` + }, [size]) + + return
; +}; + +export default CodeMirrorEditor; diff --git a/web/src/components/Markdown/index.tsx b/web/src/components/Markdown/index.tsx index a2fac5ba..9d3c482b 100644 --- a/web/src/components/Markdown/index.tsx +++ b/web/src/components/Markdown/index.tsx @@ -81,7 +81,7 @@ const components = { audio: ({ src, ...props }: any) => , a: ({ href, children, ...props }: any) => {children}, button: ({ children }: any) => {[children]}, - table: ({ children, ...props }: any) => {children}
, + table: ({ children, ...props }: any) =>
{children}
, tr: ({ children, ...props }: any) => {children}, th: ({ children, ...props }: any) => {children}, td: ({ children, ...props }: any) => {children}, diff --git a/web/src/styles/index.css b/web/src/styles/index.css index bbbe9cd9..d937396a 100644 --- a/web/src/styles/index.css +++ b/web/src/styles/index.css @@ -180,4 +180,9 @@ body { .x6-node foreignObject > body { min-height: 100%; max-height: 100%; +} + +.ͼ2 .cm-gutters { + background-color: #FFFFFF; + border: none; } \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx index 297e9faa..7fdf1ab2 100644 --- a/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx @@ -140,7 +140,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi title={t('application.knowledgeBaseAssociation')} extra={ - + } diff --git a/web/src/views/MemoryManagement/components/MemoryForm.tsx b/web/src/views/MemoryManagement/components/MemoryForm.tsx index 22bff65a..93246ca9 100644 --- a/web/src/views/MemoryManagement/components/MemoryForm.tsx +++ b/web/src/views/MemoryManagement/components/MemoryForm.tsx @@ -16,7 +16,7 @@ import { useTranslation } from 'react-i18next'; import type { MemoryFormData, Memory, MemoryFormRef } from '../types'; import RbModal from '@/components/RbModal' import { createMemoryConfig, updateMemoryConfig } from '@/api/memory' -import { getOntologyScenesUrl } from '@/api/ontology' +import { getOntologyScenesSimpleUrl } from '@/api/ontology' import CustomSelect from '@/components/CustomSelect'; const FormItem = Form.Item; @@ -129,8 +129,7 @@ const MemoryForm = forwardRef(({ > { title={item.config_name} > -
{item.config_desc}
+
{item.config_desc}
diff --git a/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx b/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx index 8a21e012..169d9690 100644 --- a/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx +++ b/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx @@ -103,9 +103,9 @@ const MultiKeyConfigModal = forwardRef 0 && (
{model.api_keys.map((key) => ( -
-
-
{key.api_key}
+
+
+
{key.api_key}
{key.api_base}
diff --git a/web/src/views/Workflow/components/Editor/index.tsx b/web/src/views/Workflow/components/Editor/index.tsx index 4c8540a8..60da03a7 100644 --- a/web/src/views/Workflow/components/Editor/index.tsx +++ b/web/src/views/Workflow/components/Editor/index.tsx @@ -15,8 +15,6 @@ import CharacterCountPlugin from './plugin/CharacterCountPlugin' import InitialValuePlugin from './plugin/InitialValuePlugin'; import CommandPlugin from './plugin/CommandPlugin'; import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin'; -import Python3HighlightPlugin from './plugin/Python3HighlightPlugin'; -import JavaScriptHighlightPlugin from './plugin/JavaScriptHighlightPlugin'; import LineNumberPlugin from './plugin/LineNumberPlugin'; import BlurPlugin from './plugin/BlurPlugin'; import { VariableNode } from './nodes/VariableNode' @@ -32,7 +30,7 @@ export interface LexicalEditorProps { lineHeight?: number; size?: 'default' | 'small'; type?: 'input' | 'textarea', - language?: 'string' | 'jinja2' | 'python3' | 'javascript' + language?: 'string' | 'jinja2' } const theme = { @@ -67,7 +65,7 @@ const Editor: FC =({ const [enableLineNumbers, setEnableLineNumbers] = useState(false) useEffect(() => { - const needsLineNumbers = language === 'jinja2' || language === 'python3' || language === 'javascript'; + const needsLineNumbers = language === 'jinja2'; setEnableJinja2(language === 'jinja2'); setEnableLineNumbers(needsLineNumbers); @@ -237,13 +235,11 @@ const Editor: FC =({ {language === 'jinja2' && } - {language === 'python3' && } - {language === 'javascript' && } {enableLineNumbers && } { setCount(count) }} onChange={onChange} /> - - {enableLineNumbers && } + + {enableJinja2 && }
); diff --git a/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx deleted file mode 100644 index 21219139..00000000 --- a/web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx +++ /dev/null @@ -1,182 +0,0 @@ -import { useEffect, useRef } from 'react'; -import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical'; - -const JS_KEYWORDS = new Set([ - 'async', 'await', 'break', 'case', 'catch', 'class', 'const', 'continue', 'debugger', 'default', - 'delete', 'do', 'else', 'export', 'extends', 'finally', 'for', 'function', 'if', 'import', - 'in', 'instanceof', 'let', 'new', 'return', 'super', 'switch', 'this', 'throw', 'try', - 'typeof', 'var', 'void', 'while', 'with', 'yield', 'true', 'false', 'null', 'undefined' -]); - -const JavaScriptHighlightPlugin = () => { - const [editor] = useLexicalComposerContext(); - const isPastingRef = useRef(false); - - useEffect(() => { - return editor.registerCommand( - PASTE_COMMAND, - () => { - isPastingRef.current = true; - setTimeout(() => { - isPastingRef.current = false; - }, 100); - return false; - }, - COMMAND_PRIORITY_LOW - ); - }, [editor]); - - useEffect(() => { - return editor.registerNodeTransform(TextNode, (textNode: TextNode) => { - if (isPastingRef.current) return; - - const text = textNode.getTextContent(); - - if (textNode.hasFormat('code')) return; - if (!needsHighlight(text)) return; - if (textNode.getStyle()) return; - - const parent = textNode.getParent(); - if (!parent) return; - - const selection = $getSelection(); - let selectionOffset = null; - if ($isRangeSelection(selection)) { - const anchor = selection.anchor; - if (anchor.getNode() === textNode) { - selectionOffset = anchor.offset; - } - } - - const tokens = tokenizeJavaScript(text); - if (tokens.length <= 1) return; - - const newNodes = tokens.map(token => { - const newNode = $createTextNode(token.text); - newNode.toggleFormat('code'); - - switch (token.type) { - case 'keyword': - newNode.setStyle('color: #d73a49; font-weight: 600;'); - break; - case 'string': - newNode.setStyle('color: #032f62;'); - break; - case 'comment': - newNode.setStyle('color: #6a737d; font-style: italic;'); - break; - case 'number': - newNode.setStyle('color: #005cc5; font-weight: 500;'); - break; - case 'function': - newNode.setStyle('color: #6f42c1; font-weight: 500;'); - break; - } - - return newNode; - }); - - if (newNodes.length > 1) { - textNode.replace(newNodes[0]); - for (let i = 1; i < newNodes.length; i++) { - newNodes[i - 1].insertAfter(newNodes[i]); - } - - if (selectionOffset !== null && $isRangeSelection(selection)) { - let currentOffset = 0; - for (const node of newNodes) { - const nodeLength = node.getTextContent().length; - if (currentOffset + nodeLength >= selectionOffset) { - node.select(selectionOffset - currentOffset, selectionOffset - currentOffset); - break; - } - currentOffset += nodeLength; - } - } - } - }); - }, [editor]); - - return null; -}; - -function needsHighlight(text: string): boolean { - return /[a-zA-Z0-9_/"'`]/.test(text); -} - -function tokenizeJavaScript(text: string): Array<{text: string, type: string}> { - const tokens: Array<{text: string, type: string}> = []; - let i = 0; - - while (i < text.length) { - // Single-line comments - if (text.slice(i, i + 2) === '//') { - let start = i; - while (i < text.length && text[i] !== '\n') i++; - tokens.push({ text: text.slice(start, i), type: 'comment' }); - continue; - } - - // Multi-line comments - if (text.slice(i, i + 2) === '/*') { - let start = i; - i += 2; - while (i < text.length && text.slice(i, i + 2) !== '*/') i++; - if (i < text.length) i += 2; - tokens.push({ text: text.slice(start, i), type: 'comment' }); - continue; - } - - // Strings - if (text[i] === '"' || text[i] === "'" || text[i] === '`') { - const quote = text[i]; - let start = i++; - - while (i < text.length) { - if (text[i] === quote && text[i - 1] !== '\\') { - i++; - break; - } - i++; - } - tokens.push({ text: text.slice(start, i), type: 'string' }); - continue; - } - - // Numbers - if (/\d/.test(text[i])) { - let start = i; - while (i < text.length && /[\d.]/.test(text[i])) i++; - tokens.push({ text: text.slice(start, i), type: 'number' }); - continue; - } - - // Keywords and identifiers - if (/[a-zA-Z_$]/.test(text[i])) { - let start = i; - while (i < text.length && /[a-zA-Z0-9_$]/.test(text[i])) i++; - const word = text.slice(start, i); - - if (JS_KEYWORDS.has(word)) { - tokens.push({ text: word, type: 'keyword' }); - } else if (i < text.length && text[i] === '(') { - tokens.push({ text: word, type: 'function' }); - } else { - tokens.push({ text: word, type: 'text' }); - } - continue; - } - - // Other characters - let start = i; - while (i < text.length && !/[a-zA-Z0-9_$/"'`]/.test(text[i])) i++; - if (start < i) { - tokens.push({ text: text.slice(start, i), type: 'text' }); - } - } - - return tokens; -} - -export default JavaScriptHighlightPlugin; diff --git a/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx deleted file mode 100644 index 12830ffb..00000000 --- a/web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx +++ /dev/null @@ -1,177 +0,0 @@ -import { useEffect, useRef } from 'react'; -import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -import { TextNode, $createTextNode, $getSelection, $isRangeSelection, COMMAND_PRIORITY_LOW, PASTE_COMMAND } from 'lexical'; - -const PYTHON_KEYWORDS = new Set([ - 'False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue', - 'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', - 'in', 'is', 'lambda', 'nonlocal', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while', - 'with', 'yield' -]); - -const Python3HighlightPlugin = () => { - const [editor] = useLexicalComposerContext(); - const isPastingRef = useRef(false); - - useEffect(() => { - return editor.registerCommand( - PASTE_COMMAND, - () => { - isPastingRef.current = true; - setTimeout(() => { - isPastingRef.current = false; - }, 100); - return false; - }, - COMMAND_PRIORITY_LOW - ); - }, [editor]); - - useEffect(() => { - return editor.registerNodeTransform(TextNode, (textNode: TextNode) => { - if (isPastingRef.current) return; - - const text = textNode.getTextContent(); - - if (textNode.hasFormat('code')) return; - if (textNode.getStyle()) return; - if (!needsHighlight(text)) return; - - const parent = textNode.getParent(); - if (!parent) return; - - const selection = $getSelection(); - let selectionOffset = null; - if ($isRangeSelection(selection)) { - const anchor = selection.anchor; - if (anchor.getNode() === textNode) { - selectionOffset = anchor.offset; - } - } - - const tokens = tokenizePython(text); - if (tokens.length <= 1) return; - - const newNodes = tokens.map(token => { - const newNode = $createTextNode(token.text); - newNode.toggleFormat('code'); - - switch (token.type) { - case 'keyword': - newNode.setStyle('color: #d73a49; font-weight: 600;'); - break; - case 'string': - newNode.setStyle('color: #032f62;'); - break; - case 'comment': - newNode.setStyle('color: #6a737d; font-style: italic;'); - break; - case 'number': - newNode.setStyle('color: #005cc5; font-weight: 500;'); - break; - case 'function': - newNode.setStyle('color: #6f42c1; font-weight: 500;'); - break; - } - - return newNode; - }); - - if (newNodes.length > 1) { - textNode.replace(newNodes[0]); - for (let i = 1; i < newNodes.length; i++) { - newNodes[i - 1].insertAfter(newNodes[i]); - } - - if (selectionOffset !== null && $isRangeSelection(selection)) { - let currentOffset = 0; - for (const node of newNodes) { - const nodeLength = node.getTextContent().length; - if (currentOffset + nodeLength >= selectionOffset) { - node.select(selectionOffset - currentOffset, selectionOffset - currentOffset); - break; - } - currentOffset += nodeLength; - } - } - } - }); - }, [editor]); - - return null; -}; - -function needsHighlight(text: string): boolean { - return /[a-zA-Z0-9_#"']/.test(text); -} - -function tokenizePython(text: string): Array<{text: string, type: string}> { - const tokens: Array<{text: string, type: string}> = []; - let i = 0; - - while (i < text.length) { - // Comments - if (text[i] === '#') { - let start = i; - while (i < text.length && text[i] !== '\n') i++; - tokens.push({ text: text.slice(start, i), type: 'comment' }); - continue; - } - - // Strings - if (text[i] === '"' || text[i] === "'") { - const quote = text[i]; - let start = i++; - const isTriple = text.slice(start, start + 3) === quote.repeat(3); - if (isTriple) i += 2; - - while (i < text.length) { - if (isTriple && text.slice(i, i + 3) === quote.repeat(3)) { - i += 3; - break; - } else if (!isTriple && text[i] === quote && text[i - 1] !== '\\') { - i++; - break; - } - i++; - } - tokens.push({ text: text.slice(start, i), type: 'string' }); - continue; - } - - // Numbers - if (/\d/.test(text[i])) { - let start = i; - while (i < text.length && /[\d.]/.test(text[i])) i++; - tokens.push({ text: text.slice(start, i), type: 'number' }); - continue; - } - - // Keywords and identifiers - if (/[a-zA-Z_]/.test(text[i])) { - let start = i; - while (i < text.length && /[a-zA-Z0-9_]/.test(text[i])) i++; - const word = text.slice(start, i); - - if (PYTHON_KEYWORDS.has(word)) { - tokens.push({ text: word, type: 'keyword' }); - } else if (i < text.length && text[i] === '(') { - tokens.push({ text: word, type: 'function' }); - } else { - tokens.push({ text: word, type: 'text' }); - } - continue; - } - - // Other characters - let start = i; - while (i < text.length && !/[a-zA-Z0-9_#"']/.test(text[i])) i++; - if (start < i) { - tokens.push({ text: text.slice(start, i), type: 'text' }); - } - } - - return tokens; -} - -export default Python3HighlightPlugin; diff --git a/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx b/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx index 8a0ea03e..b9c2c881 100644 --- a/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx +++ b/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx @@ -5,8 +5,8 @@ import { Node } from '@antv/x6' import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' import MappingList from '../MappingList' -import Editor from '../../Editor' import OutputList from './OutputList' +import CodeMirrorEditor from '@/components/CodeMirrorEditor'; interface MappingItem { name?: string @@ -110,7 +110,10 @@ const CodeExecution: FC = ({ options }) => { prev.language !== curr.language}> {() => ( - + )} diff --git a/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx b/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx index da9603c8..3cd7efcd 100644 --- a/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx +++ b/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx @@ -126,7 +126,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
- {t('workflow.config.knowledge-retrieval.recallConfig')} + {t('application.globalConfig')}