diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 2300f148..f55ea5b5 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -60,14 +60,14 @@ def list_apps( """ workspace_id = current_user.current_workspace_id service = app_service.AppService(db) - + # 当 ids 存在且不为 None 时,根据 ids 获取应用 if ids is not None: app_ids = [id.strip() for id in ids.split(',') if id.strip()] items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) items = [service._convert_to_schema(app, workspace_id) for app in items_orm] return success(data=items) - + # 正常分页查询 items_orm, total = app_service.list_apps( db, diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index 913874f1..1a2e3cbc 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -620,34 +620,52 @@ class AccessHistoryManager: new_version = current_version + 1 # 步骤2:使用乐观锁更新节点 - # 只有当版本号匹配时才更新 - update_query = f""" - MATCH (n:{node_label} {{id: $node_id}}) - """ + # 根据节点类型构建完整的查询语句 + content_field_map = { + 'Statement': 'n.statement as statement', + 'MemorySummary': 'n.content as content', + 'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤 + } + + # 显式检查节点类型,不支持的类型抛出错误 + if node_label not in content_field_map: + raise ValueError( + f"Unsupported node_label: {node_label}. " + f"Supported labels are: {list(content_field_map.keys())}" + ) + + content_field = content_field_map[node_label] + + # 构建 WHERE 子句 + where_conditions = [] if group_id: - update_query += " WHERE n.group_id = $group_id" + where_conditions.append("n.group_id = $group_id") # 添加版本检查 if current_version > 0: - update_query += " AND n.version = $current_version" + where_conditions.append("n.version = $current_version") else: - # 如果节点没有版本号,检查是否为首次更新 - update_query += " AND (n.version IS NULL OR n.version = 0)" + where_conditions.append("(n.version IS NULL OR n.version = 0)") - update_query += """ + where_clause = " AND ".join(where_conditions) if where_conditions else "true" + + # 构建完整的更新查询 + update_query = f""" + MATCH (n:{node_label} {{id: $node_id}}) + WHERE {where_clause} SET n.activation_value = $activation_value, n.access_history = $access_history, n.last_access_time = $last_access_time, n.access_count = $access_count, n.version = $new_version RETURN n.id as id, - n.statement as statement, n.activation_value as activation_value, n.access_history as access_history, n.last_access_time as last_access_time, n.access_count as access_count, n.importance_score as importance_score, - n.version as version + n.version as version, + {content_field} """ update_params = { @@ -671,7 +689,11 @@ class AccessHistoryManager: f"Expected version {current_version}, but node was modified by another transaction." ) - return dict(updated_node) + # 转换为字典并移除占位符字段 + result_dict = dict(updated_node) + result_dict.pop('content_placeholder', None) + + return result_dict # 执行事务 try: diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index e3d634d8..67689935 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -3,13 +3,11 @@ 基于 LangGraph 的工作流执行引擎。 """ - -# import uuid import datetime import logging +import uuid from typing import Any -from langchain_core.messages import HumanMessage from langgraph.graph.state import CompiledStateGraph from app.core.workflow.graph_builder import GraphBuilder @@ -55,6 +53,12 @@ class WorkflowExecutor: self.edges = workflow_config.get("edges", []) self.execution_config = workflow_config.get("execution_config", {}) + self.checkpoint_config = { + "configurable": { + "thread_id": uuid.uuid4(), + } + } + def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: """准备初始状态(注入系统变量和会话变量) @@ -95,7 +99,7 @@ class WorkflowExecutor: case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING: conversation_vars[var_name] = [] input_variables = input_data.get("variables") or {} # Start 节点的自定义变量 - + conversation_vars = conversation_vars | input_data.get("conv", {}) # 构建分层的变量结构 variables = { "sys": { @@ -110,7 +114,7 @@ class WorkflowExecutor: } return { - "messages": [HumanMessage(content=user_message)], + "messages": [('user', user_message)], "variables": variables, "node_outputs": {}, "runtime_vars": {}, # 运行时节点变量(简化版,供快速访问) @@ -196,6 +200,28 @@ class WorkflowExecutor: logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}") return prefixes, adjacent_and_referenced + def _build_final_output(self, result, elapsed_time): + node_outputs = result.get("node_outputs", {}) + final_output = self._extract_final_output(node_outputs) + token_usage = self._aggregate_token_usage(node_outputs) + conversation_id = None + for node_id, node_output in node_outputs.items(): + if node_output.get("node_type") == "start": + conversation_id = node_output.get("output", {}).get("conversation_id") + break + + return { + "status": "completed", + "output": final_output, + "node_outputs": node_outputs, + "messages": result.get("messages", []), + "conversation_id": conversation_id, + "elapsed_time": elapsed_time, + "token_usage": token_usage, + "error": result.get("error"), + "variables": result.get("variables", {}), + } + def build_graph(self, stream=False) -> CompiledStateGraph: """构建 LangGraph @@ -236,40 +262,16 @@ class WorkflowExecutor: # 3. 执行工作流 try: - result = await graph.ainvoke(initial_state) + + result = await graph.ainvoke(initial_state, config=self.checkpoint_config) # 计算耗时 end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - # 提取节点输出(现在包含 start 和 end 节点) - node_outputs = result.get("node_outputs", {}) - - # 提取最终输出(从最后一个非 start/end 节点) - final_output = self._extract_final_output(node_outputs) - - # 聚合 token 使用情况 - token_usage = self._aggregate_token_usage(node_outputs) - - # 提取 conversation_id(从 start 节点输出) - conversation_id = None - for node_id, node_output in node_outputs.items(): - if node_output.get("node_type") == "start": - conversation_id = node_output.get("output", {}).get("conversation_id") - break - logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s") - return { - "status": "completed", - "output": final_output, - "node_outputs": node_outputs, - "messages": result.get("messages", []), - "conversation_id": conversation_id, - "elapsed_time": elapsed_time, - "token_usage": token_usage, - "error": result.get("error") - } + return self._build_final_output(result, elapsed_time) except Exception as e: # 计算耗时(即使失败也记录) @@ -331,11 +333,11 @@ class WorkflowExecutor: # 3. Execute workflow try: chunk_count = 0 - final_state = None async for event in graph.astream( initial_state, stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode + config=self.checkpoint_config ): # event should be a tuple: (mode, data) # But let's handle both cases @@ -411,12 +413,11 @@ class WorkflowExecutor: elif mode == "updates": # Handle state updates - store final state logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}") - final_state = data # 计算耗时 end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - + result = graph.get_state(self.checkpoint_config).values logger.info( f"Workflow execution completed (streaming), " f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s" @@ -425,12 +426,7 @@ class WorkflowExecutor: # 发送 workflow_end 事件 yield { "event": "workflow_end", - "data": { - "execution_id": self.execution_id, - "status": "completed", - "elapsed_time": elapsed_time, - "timestamp": end_time.isoformat() - } + "data": self._build_final_output(result, elapsed_time) } except Exception as e: diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/graph_builder.py index 69ed3b6a..b75b867e 100644 --- a/api/app/core/workflow/graph_builder.py +++ b/api/app/core/workflow/graph_builder.py @@ -4,6 +4,7 @@ from typing import Any from langgraph.graph.state import CompiledStateGraph, StateGraph from langgraph.graph import START, END +from langgraph.checkpoint.memory import InMemorySaver from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.nodes import WorkflowState, NodeFactory @@ -249,4 +250,5 @@ class GraphBuilder: self.graph = StateGraph(WorkflowState) self.add_nodes() self.add_edges() # 添加边必须在添加节点之后 - return self.graph.compile() + checkpointer = InMemorySaver() + return self.graph.compile(checkpointer=checkpointer) diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index 7b9d645b..96f68ce8 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) class AssignerNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = AssignerNodeConfig(**self.config) + self.typed_config: AssignerNodeConfig | None = None async def execute(self, state: WorkflowState) -> Any: """ @@ -28,6 +28,7 @@ class AssignerNode(BaseNode): None or the result of the assignment operation. """ # Initialize a variable pool for accessing conversation, node, and system variables + self.typed_config = AssignerNodeConfig(**self.config) logger.info(f"节点 {self.node_id} 开始执行") pool = VariablePool(state) for assignment in self.typed_config.assignments: diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 727f7391..e3bf36c9 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -25,7 +25,7 @@ class WorkflowState(TypedDict): The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc. """ # List of messages (append mode) - messages: Annotated[list[AnyMessage], add] + messages: Annotated[list[tuple[str, str]], add] # Set of loop node IDs, used for assigning values in loop nodes cycle_nodes: list @@ -203,6 +203,7 @@ class BaseNode(ABC): # 返回包装后的输出和运行时变量 return { **wrapped_output, + "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var }, @@ -355,6 +356,7 @@ class BaseNode(ABC): # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) state_update = { **final_output, + "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var }, diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index fb062f39..1659395e 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -30,7 +30,6 @@ class CycleGraphNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None self.cycle_nodes = list() # Nodes belonging to this cycle self.cycle_edges = list() # Edges connecting nodes within the cycle diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 2e5de796..141cba79 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -32,7 +32,7 @@ class HttpRequestNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = HttpRequestNodeConfig(**self.config) + self.typed_config: HttpRequestNodeConfig | None = None def _build_timeout(self) -> Timeout: """ @@ -181,6 +181,7 @@ class HttpRequestNode(BaseNode): - dict: Serialized HttpRequestNodeOutput on success - str: Branch identifier (e.g. "ERROR") when branching is enabled """ + self.typed_config = HttpRequestNodeConfig(**self.config) async with httpx.AsyncClient( verify=self.typed_config.verify_ssl, timeout=self._build_timeout(), diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 8c6d222f..41f1138b 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) class IfElseNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = IfElseNodeConfig(**self.config) + self.typed_config: IfElseNodeConfig | None= None @staticmethod def _evaluate(operator, instance: CompareOperatorInstance) -> Any: @@ -109,6 +109,7 @@ class IfElseNode(BaseNode): Returns: str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions. """ + self.typed_config = IfElseNodeConfig(**self.config) expressions = self.evaluate_conditional_edge_expressions(state) # TODO: 变量类型及文本类型解析 for i in range(len(expressions)): diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py index 70993573..822f1918 100644 --- a/api/app/core/workflow/nodes/jinja_render/node.py +++ b/api/app/core/workflow/nodes/jinja_render/node.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) class JinjaRenderNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = JinjaRenderNodeConfig(**self.config) + self.typed_config: JinjaRenderNodeConfig | None = None async def execute(self, state: WorkflowState) -> Any: """ @@ -34,6 +34,7 @@ class JinjaRenderNode(BaseNode): RuntimeError: If Jinja2 template rendering fails due to invalid template syntax or missing variables. """ + self.typed_config = JinjaRenderNodeConfig(**self.config) render = TemplateRenderer(strict=False) context = {} diff --git a/api/app/core/workflow/nodes/knowledge/config.py b/api/app/core/workflow/nodes/knowledge/config.py index 9d307216..5475636e 100644 --- a/api/app/core/workflow/nodes/knowledge/config.py +++ b/api/app/core/workflow/nodes/knowledge/config.py @@ -44,8 +44,8 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig): description="Knowledge base config" ) - reranker_id: UUID = Field( - default="", + reranker_id: UUID | None = Field( + default=None, description="Reranker top k" ) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 061328e1..221ca079 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) class KnowledgeRetrievalNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) + self.typed_config: KnowledgeRetrievalNodeConfig | None = None @staticmethod def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType): @@ -171,6 +171,7 @@ class KnowledgeRetrievalNode(BaseNode): Raises: RuntimeError: If no valid knowledge base is found or access is denied. """ + self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) query = self._render_template(self.typed_config.query, state) with get_db_read() as db: knowledge_bases = self.typed_config.knowledge_bases diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 5fb86ae2..6395d3b8 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -68,7 +68,7 @@ class LLMNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = LLMNodeConfig(**self.config) + self.typed_config: LLMNodeConfig | None = None def _render_context(self, message, state): context = f"{self._render_template(self.typed_config.context, state)}" @@ -164,6 +164,7 @@ class LLMNode(BaseNode): Returns: LLM 响应消息 """ + self.typed_config = LLMNodeConfig(**self.config) llm, prompt_or_messages = self._prepare_llm(state, True) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 0d1b1fb4..f1c99ddb 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -10,9 +10,10 @@ from app.services.memory_agent_service import MemoryAgentService class MemoryReadNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = MemoryReadNodeConfig(**self.config) + self.typed_config: MemoryReadNodeConfig | None = None async def execute(self, state: WorkflowState) -> Any: + self.typed_config = MemoryReadNodeConfig(**self.config) with get_db_read() as db: workspace_id = self.get_variable('sys.workspace_id', state) end_user_id = self.get_variable("sys.user_id", state) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 84d61aa9..ec58d96c 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class ParameterExtractorNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = ParameterExtractorNodeConfig(**self.config) + self.typed_config: ParameterExtractorNodeConfig | None = None @staticmethod def _get_prompt(): @@ -145,6 +145,7 @@ class ParameterExtractorNode(BaseNode): Raises: BusinessException: If LLM output cannot be parsed as valid JSON. """ + self.typed_config = ParameterExtractorNodeConfig(**self.config) llm = self._get_llm_instance() system_prompt, user_prompt = self._get_prompt() diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index b0f2c28d..aee72eda 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -21,8 +21,8 @@ class QuestionClassifierNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = QuestionClassifierNodeConfig(**self.config) - self.category_to_case_map = self._build_category_case_map() + self.typed_config: QuestionClassifierNodeConfig | None = None + self.category_to_case_map = {} def _get_llm_instance(self) -> RedBearLLM: """获取LLM实例""" @@ -67,6 +67,8 @@ class QuestionClassifierNode(BaseNode): async def execute(self, state: WorkflowState) -> dict: """执行问题分类""" + self.typed_config = QuestionClassifierNodeConfig(**self.config) + self.category_to_case_map = self._build_category_case_map() question = self.typed_config.input_variable supplement_prompt = self.typed_config.user_supplement_prompt or "" categories = self.typed_config.categories or [] diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index 7c3a2fca..69560422 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -7,6 +7,7 @@ Start 节点实现 import logging from typing import Any +from app.core.workflow.nodes.base_config import VariableType from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.start.config import StartNodeConfig @@ -34,7 +35,7 @@ class StartNode(BaseNode): super().__init__(node_config, workflow_config) # 解析并验证配置 - self.typed_config = StartNodeConfig(**self.config) + self.typed_config: StartNodeConfig | None = None async def execute(self, state: WorkflowState) -> dict[str, Any]: """执行 start 节点业务逻辑 @@ -47,6 +48,7 @@ class StartNode(BaseNode): Returns: 包含系统参数、会话变量和自定义变量的字典 """ + self.typed_config = StartNodeConfig(**self.config) logger.info(f"节点 {self.node_id} (Start) 开始执行") # 创建变量池实例(在方法内复用) @@ -113,6 +115,18 @@ class StartNode(BaseNode): logger.debug( f"变量 '{var_name}' 使用默认值: {var_def.default}" ) + else: + match var_def.type: + case VariableType.STRING: + processed[var_name] = "" + case VariableType.NUMBER: + processed[var_name] = 0 + case VariableType.OBJECT: + processed[var_name] = {} + case VariableType.BOOLEAN: + processed[var_name] = False + case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING: + processed[var_name] = [] return processed diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 496812a6..3e79b075 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -19,10 +19,11 @@ class ToolNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = ToolNodeConfig(**self.config) + self.typed_config: ToolNodeConfig | None = None async def execute(self, state: WorkflowState) -> dict[str, Any]: """执行工具""" + self.typed_config = ToolNodeConfig(**self.config) # 获取租户ID和用户ID tenant_id = self.get_variable("sys.tenant_id", state) user_id = self.get_variable("sys.user_id", state) diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py index e6cbf75b..5bff8e33 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/node.py +++ b/api/app/core/workflow/nodes/variable_aggregator/node.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) class VariableAggregatorNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = VariableAggregatorNodeConfig(**self.config) + self.typed_config: VariableAggregatorNodeConfig | None = None @staticmethod def _get_express(variable_string: str) -> Any: @@ -37,6 +37,7 @@ class VariableAggregatorNode(BaseNode): - str: In non-group mode, returns the first non-None variable value. - dict: In group mode, returns a mapping of group_name -> first non-None variable value. """ + self.typed_config = VariableAggregatorNodeConfig(**self.config) if not self.typed_config.group: # -------------------------- # Non-group mode diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 1549ef86..80756793 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -66,24 +66,38 @@ async def _update_activation_values_batch( max_retries=max_retries ) - # 提取节点ID列表 - node_ids = [node.get('id') for node in nodes if node.get('id')] + # 提取节点ID列表并去重(保持原始顺序) + seen_ids = set() + unique_node_ids = [] + for node in nodes: + node_id = node.get('id') + if node_id and node_id not in seen_ids: + seen_ids.add(node_id) + unique_node_ids.append(node_id) - if not node_ids: + if not unique_node_ids: logger.warning(f"批量更新激活值:没有有效的节点ID") return nodes + + # 记录去重信息(仅针对具有有效 ID 的节点) + id_nodes_count = sum(1 for n in nodes if n.get("id")) + if len(unique_node_ids) < id_nodes_count: + logger.info( + f"批量更新激活值:检测到重复节点,具有有效ID的节点数量={id_nodes_count}, " + f"去重后唯一ID数量={len(unique_node_ids)}" + ) # 批量记录访问 try: updated_nodes = await access_manager.record_batch_access( - node_ids=node_ids, + node_ids=unique_node_ids, node_label=node_label, group_id=group_id ) logger.info( f"批量更新激活值成功: {node_label}, " - f"更新数量={len(updated_nodes)}/{len(node_ids)}" + f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}" ) return updated_nodes @@ -153,19 +167,38 @@ async def _update_search_results_activation( original_nodes = results[key] updated_nodes = update_result - # 创建 ID 到原始节点的映射(用于快速查找 score) - original_map = {node.get('id'): node for node in original_nodes if node.get('id')} + # 创建 ID 到更新节点的映射(用于快速查找激活值数据) + updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')} - # 合并数据:激活值来自更新结果,score 来自原始结果 + # 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充 merged_nodes = [] - for updated_node in updated_nodes: - node_id = updated_node.get('id') - if node_id and node_id in original_map: - # 保留原始的 score 字段 - original_score = original_map[node_id].get('score') - if original_score is not None: - updated_node['score'] = original_score - merged_nodes.append(updated_node) + for original_node in original_nodes: + node_id = original_node.get('id') + if node_id and node_id in updated_map: + # 从原始节点开始,用更新后的激活值数据覆盖 + merged_node = original_node.copy() + + # 更新激活值相关字段 + activation_fields = { + 'activation_value', + 'access_history', + 'last_access_time', + 'access_count', + 'importance_score', + 'version', + 'statement', # Statement 节点的内容字段 + 'content' # MemorySummary 节点的内容字段 + } + + # 只更新激活值相关字段,保留原始节点的其他字段 + for field in activation_fields: + if field in updated_map[node_id]: + merged_node[field] = updated_map[node_id][field] + + merged_nodes.append(merged_node) + else: + # 如果没有更新数据,保留原始节点 + merged_nodes.append(original_node) updated_results[key] = merged_nodes else: diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index acea60b7..569684d5 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -15,6 +15,7 @@ from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session +from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger @@ -22,6 +23,7 @@ from app.core.rag.nlp.search import knowledge_retrieval from app.models import AgentConfig, ModelApiKey, ModelConfig from app.repositories.tool_repository import ToolRepository from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message +from app.services import task_service from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger @@ -101,6 +103,14 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str user_rag_memory_id=user_rag_memory_id ) ) + task = celery_app.send_task( + "app.core.memory.agent.read_message", + args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id] + ) + result = task_service.get_task_memory_read_result(task.id) + status = result.get("status") + logger.info(f"读取任务状态:{status}") + finally: db.close() logger.info(f'用户ID:Agent:{end_user_id}') diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 68d6279b..7d3c784f 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -491,6 +491,17 @@ class WorkflowService: ) end_user_id = str(new_end_user.id) + executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) + + for exec_res in executions: + if exec_res.status == "completed": + last_state = exec_res.output_data + if isinstance(last_state, dict): + variables = last_state.get("variables", {}) + conv_vars = variables.get("conv", {}) + input_data["conv"] = conv_vars + break + result = await execute_workflow( workflow_config=workflow_config_dict, input_data=input_data, @@ -504,7 +515,7 @@ class WorkflowService: self.update_execution_status( execution.execution_id, "completed", - output_data=result.get("node_outputs", {}) + output_data=result ) else: self.update_execution_status( @@ -517,6 +528,7 @@ class WorkflowService: return { "execution_id": execution.execution_id, "status": result.get("status"), + "variables": result.get("variables"), "output": result.get("output"), # 最终输出(字符串) "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID @@ -617,6 +629,16 @@ class WorkflowService: original_user_id=payload.user_id # Save original user_id to other_id ) end_user_id = str(new_end_user.id) + executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) + + for exec_res in executions: + if exec_res.status == "completed": + last_state = exec_res.output_data + if isinstance(last_state, dict): + variables = last_state.get("variables", {}) + conv_vars = variables.get("conv", {}) + input_data["conv"] = conv_vars + break # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) async for event in self._run_workflow_stream( @@ -827,6 +849,23 @@ class WorkflowService: user_id=user_id ): # 直接转发事件(executor 已经返回正确格式) + if event.get("event") == "workflow_end": + + status = event.get("data", {}).get("status") + if status == "completed": + self.update_execution_status( + execution_id, + "completed", + output_data=event.get("data") + ) + elif status == "failed": + self.update_execution_status( + execution_id, + "failed", + output_data=event.get("data") + ) + else: + logger.error(f"unexpect workflow run status, status: {status}") yield event except Exception as e: