Merge branch 'feature/20260105_xjn' into feature/agent-tool_xjn
This commit is contained in:
@@ -60,14 +60,14 @@ def list_apps(
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
if ids is not None:
|
||||
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
|
||||
|
||||
# 正常分页查询
|
||||
items_orm, total = app_service.list_apps(
|
||||
db,
|
||||
|
||||
@@ -620,34 +620,52 @@ class AccessHistoryManager:
|
||||
new_version = current_version + 1
|
||||
|
||||
# 步骤2:使用乐观锁更新节点
|
||||
# 只有当版本号匹配时才更新
|
||||
update_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
# 根据节点类型构建完整的查询语句
|
||||
content_field_map = {
|
||||
'Statement': 'n.statement as statement',
|
||||
'MemorySummary': 'n.content as content',
|
||||
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤
|
||||
}
|
||||
|
||||
# 显式检查节点类型,不支持的类型抛出错误
|
||||
if node_label not in content_field_map:
|
||||
raise ValueError(
|
||||
f"Unsupported node_label: {node_label}. "
|
||||
f"Supported labels are: {list(content_field_map.keys())}"
|
||||
)
|
||||
|
||||
content_field = content_field_map[node_label]
|
||||
|
||||
# 构建 WHERE 子句
|
||||
where_conditions = []
|
||||
if group_id:
|
||||
update_query += " WHERE n.group_id = $group_id"
|
||||
where_conditions.append("n.group_id = $group_id")
|
||||
|
||||
# 添加版本检查
|
||||
if current_version > 0:
|
||||
update_query += " AND n.version = $current_version"
|
||||
where_conditions.append("n.version = $current_version")
|
||||
else:
|
||||
# 如果节点没有版本号,检查是否为首次更新
|
||||
update_query += " AND (n.version IS NULL OR n.version = 0)"
|
||||
where_conditions.append("(n.version IS NULL OR n.version = 0)")
|
||||
|
||||
update_query += """
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "true"
|
||||
|
||||
# 构建完整的更新查询
|
||||
update_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
WHERE {where_clause}
|
||||
SET n.activation_value = $activation_value,
|
||||
n.access_history = $access_history,
|
||||
n.last_access_time = $last_access_time,
|
||||
n.access_count = $access_count,
|
||||
n.version = $new_version
|
||||
RETURN n.id as id,
|
||||
n.statement as statement,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score,
|
||||
n.version as version
|
||||
n.version as version,
|
||||
{content_field}
|
||||
"""
|
||||
|
||||
update_params = {
|
||||
@@ -671,7 +689,11 @@ class AccessHistoryManager:
|
||||
f"Expected version {current_version}, but node was modified by another transaction."
|
||||
)
|
||||
|
||||
return dict(updated_node)
|
||||
# 转换为字典并移除占位符字段
|
||||
result_dict = dict(updated_node)
|
||||
result_dict.pop('content_placeholder', None)
|
||||
|
||||
return result_dict
|
||||
|
||||
# 执行事务
|
||||
try:
|
||||
|
||||
@@ -3,13 +3,11 @@
|
||||
|
||||
基于 LangGraph 的工作流执行引擎。
|
||||
"""
|
||||
|
||||
# import uuid
|
||||
import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
@@ -55,6 +53,12 @@ class WorkflowExecutor:
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
self.checkpoint_config = {
|
||||
"configurable": {
|
||||
"thread_id": uuid.uuid4(),
|
||||
}
|
||||
}
|
||||
|
||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||
"""准备初始状态(注入系统变量和会话变量)
|
||||
|
||||
@@ -95,7 +99,7 @@ class WorkflowExecutor:
|
||||
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
|
||||
conversation_vars[var_name] = []
|
||||
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||
|
||||
conversation_vars = conversation_vars | input_data.get("conv", {})
|
||||
# 构建分层的变量结构
|
||||
variables = {
|
||||
"sys": {
|
||||
@@ -110,7 +114,7 @@ class WorkflowExecutor:
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [HumanMessage(content=user_message)],
|
||||
"messages": [('user', user_message)],
|
||||
"variables": variables,
|
||||
"node_outputs": {},
|
||||
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||
@@ -196,6 +200,28 @@ class WorkflowExecutor:
|
||||
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
|
||||
return prefixes, adjacent_and_referenced
|
||||
|
||||
def _build_final_output(self, result, elapsed_time):
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
conversation_id = None
|
||||
for node_id, node_output in node_outputs.items():
|
||||
if node_output.get("node_type") == "start":
|
||||
conversation_id = node_output.get("output", {}).get("conversation_id")
|
||||
break
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
"variables": result.get("variables", {}),
|
||||
}
|
||||
|
||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
@@ -236,40 +262,16 @@ class WorkflowExecutor:
|
||||
|
||||
# 3. 执行工作流
|
||||
try:
|
||||
result = await graph.ainvoke(initial_state)
|
||||
|
||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
# 提取节点输出(现在包含 start 和 end 节点)
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
|
||||
# 提取最终输出(从最后一个非 start/end 节点)
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
|
||||
# 聚合 token 使用情况
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
|
||||
# 提取 conversation_id(从 start 节点输出)
|
||||
conversation_id = None
|
||||
for node_id, node_output in node_outputs.items():
|
||||
if node_output.get("node_type") == "start":
|
||||
conversation_id = node_output.get("output", {}).get("conversation_id")
|
||||
break
|
||||
|
||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error")
|
||||
}
|
||||
return self._build_final_output(result, elapsed_time)
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
@@ -331,11 +333,11 @@ class WorkflowExecutor:
|
||||
# 3. Execute workflow
|
||||
try:
|
||||
chunk_count = 0
|
||||
final_state = None
|
||||
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||
config=self.checkpoint_config
|
||||
):
|
||||
# event should be a tuple: (mode, data)
|
||||
# But let's handle both cases
|
||||
@@ -411,12 +413,11 @@ class WorkflowExecutor:
|
||||
elif mode == "updates":
|
||||
# Handle state updates - store final state
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
|
||||
final_state = data
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s"
|
||||
@@ -425,12 +426,7 @@ class WorkflowExecutor:
|
||||
# 发送 workflow_end 事件
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"status": "completed",
|
||||
"elapsed_time": elapsed_time,
|
||||
"timestamp": end_time.isoformat()
|
||||
}
|
||||
"data": self._build_final_output(result, elapsed_time)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||
from langgraph.graph import START, END
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
@@ -249,4 +250,5 @@ class GraphBuilder:
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.add_edges() # 添加边必须在添加节点之后
|
||||
return self.graph.compile()
|
||||
checkpointer = InMemorySaver()
|
||||
return self.graph.compile(checkpointer=checkpointer)
|
||||
|
||||
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||
class AssignerNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
self.typed_config: AssignerNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
@@ -28,6 +28,7 @@ class AssignerNode(BaseNode):
|
||||
None or the result of the assignment operation.
|
||||
"""
|
||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} 开始执行")
|
||||
pool = VariablePool(state)
|
||||
for assignment in self.typed_config.assignments:
|
||||
|
||||
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[AnyMessage], add]
|
||||
messages: Annotated[list[tuple[str, str]], add]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
@@ -203,6 +203,7 @@ class BaseNode(ABC):
|
||||
# 返回包装后的输出和运行时变量
|
||||
return {
|
||||
**wrapped_output,
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
},
|
||||
@@ -355,6 +356,7 @@ class BaseNode(ABC):
|
||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
},
|
||||
|
||||
@@ -30,7 +30,6 @@ class CycleGraphNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
|
||||
|
||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
||||
|
||||
@@ -32,7 +32,7 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||
self.typed_config: HttpRequestNodeConfig | None = None
|
||||
|
||||
def _build_timeout(self) -> Timeout:
|
||||
"""
|
||||
@@ -181,6 +181,7 @@ class HttpRequestNode(BaseNode):
|
||||
- dict: Serialized HttpRequestNodeOutput on success
|
||||
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
||||
"""
|
||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||
async with httpx.AsyncClient(
|
||||
verify=self.typed_config.verify_ssl,
|
||||
timeout=self._build_timeout(),
|
||||
|
||||
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
||||
class IfElseNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = IfElseNodeConfig(**self.config)
|
||||
self.typed_config: IfElseNodeConfig | None= None
|
||||
|
||||
@staticmethod
|
||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||
@@ -109,6 +109,7 @@ class IfElseNode(BaseNode):
|
||||
Returns:
|
||||
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
||||
"""
|
||||
self.typed_config = IfElseNodeConfig(**self.config)
|
||||
expressions = self.evaluate_conditional_edge_expressions(state)
|
||||
# TODO: 变量类型及文本类型解析
|
||||
for i in range(len(expressions)):
|
||||
|
||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
class JinjaRenderNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = JinjaRenderNodeConfig(**self.config)
|
||||
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
@@ -34,6 +34,7 @@ class JinjaRenderNode(BaseNode):
|
||||
RuntimeError: If Jinja2 template rendering fails due to invalid template
|
||||
syntax or missing variables.
|
||||
"""
|
||||
self.typed_config = JinjaRenderNodeConfig(**self.config)
|
||||
render = TemplateRenderer(strict=False)
|
||||
|
||||
context = {}
|
||||
|
||||
@@ -44,8 +44,8 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
description="Knowledge base config"
|
||||
)
|
||||
|
||||
reranker_id: UUID = Field(
|
||||
default="",
|
||||
reranker_id: UUID | None = Field(
|
||||
default=None,
|
||||
description="Reranker top k"
|
||||
)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
|
||||
@staticmethod
|
||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||
@@ -171,6 +171,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
Raises:
|
||||
RuntimeError: If no valid knowledge base is found or access is denied.
|
||||
"""
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
query = self._render_template(self.typed_config.query, state)
|
||||
with get_db_read() as db:
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
|
||||
@@ -68,7 +68,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
self.typed_config: LLMNodeConfig | None = None
|
||||
|
||||
def _render_context(self, message, state):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||
@@ -164,6 +164,7 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
@@ -10,9 +10,10 @@ from app.services.memory_agent_service import MemoryAgentService
|
||||
class MemoryReadNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
self.typed_config: MemoryReadNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
with get_db_read() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
class ParameterExtractorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = ParameterExtractorNodeConfig(**self.config)
|
||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||
|
||||
@staticmethod
|
||||
def _get_prompt():
|
||||
@@ -145,6 +145,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
Raises:
|
||||
BusinessException: If LLM output cannot be parsed as valid JSON.
|
||||
"""
|
||||
self.typed_config = ParameterExtractorNodeConfig(**self.config)
|
||||
llm = self._get_llm_instance()
|
||||
system_prompt, user_prompt = self._get_prompt()
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ class QuestionClassifierNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||
self.category_to_case_map = self._build_category_case_map()
|
||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||
self.category_to_case_map = {}
|
||||
|
||||
def _get_llm_instance(self) -> RedBearLLM:
|
||||
"""获取LLM实例"""
|
||||
@@ -67,6 +67,8 @@ class QuestionClassifierNode(BaseNode):
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict:
|
||||
"""执行问题分类"""
|
||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||
self.category_to_case_map = self._build_category_case_map()
|
||||
question = self.typed_config.input_variable
|
||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||
categories = self.typed_config.categories or []
|
||||
|
||||
@@ -7,6 +7,7 @@ Start 节点实现
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
|
||||
@@ -34,7 +35,7 @@ class StartNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
|
||||
# 解析并验证配置
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
self.typed_config: StartNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行 start 节点业务逻辑
|
||||
@@ -47,6 +48,7 @@ class StartNode(BaseNode):
|
||||
Returns:
|
||||
包含系统参数、会话变量和自定义变量的字典
|
||||
"""
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||
|
||||
# 创建变量池实例(在方法内复用)
|
||||
@@ -113,6 +115,18 @@ class StartNode(BaseNode):
|
||||
logger.debug(
|
||||
f"变量 '{var_name}' 使用默认值: {var_def.default}"
|
||||
)
|
||||
else:
|
||||
match var_def.type:
|
||||
case VariableType.STRING:
|
||||
processed[var_name] = ""
|
||||
case VariableType.NUMBER:
|
||||
processed[var_name] = 0
|
||||
case VariableType.OBJECT:
|
||||
processed[var_name] = {}
|
||||
case VariableType.BOOLEAN:
|
||||
processed[var_name] = False
|
||||
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
|
||||
processed[var_name] = []
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
@@ -19,10 +19,11 @@ class ToolNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = ToolNodeConfig(**self.config)
|
||||
self.typed_config: ToolNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行工具"""
|
||||
self.typed_config = ToolNodeConfig(**self.config)
|
||||
# 获取租户ID和用户ID
|
||||
tenant_id = self.get_variable("sys.tenant_id", state)
|
||||
user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
class VariableAggregatorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = VariableAggregatorNodeConfig(**self.config)
|
||||
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||
|
||||
@staticmethod
|
||||
def _get_express(variable_string: str) -> Any:
|
||||
@@ -37,6 +37,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
- str: In non-group mode, returns the first non-None variable value.
|
||||
- dict: In group mode, returns a mapping of group_name -> first non-None variable value.
|
||||
"""
|
||||
self.typed_config = VariableAggregatorNodeConfig(**self.config)
|
||||
if not self.typed_config.group:
|
||||
# --------------------------
|
||||
# Non-group mode
|
||||
|
||||
@@ -66,24 +66,38 @@ async def _update_activation_values_batch(
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
# 提取节点ID列表
|
||||
node_ids = [node.get('id') for node in nodes if node.get('id')]
|
||||
# 提取节点ID列表并去重(保持原始顺序)
|
||||
seen_ids = set()
|
||||
unique_node_ids = []
|
||||
for node in nodes:
|
||||
node_id = node.get('id')
|
||||
if node_id and node_id not in seen_ids:
|
||||
seen_ids.add(node_id)
|
||||
unique_node_ids.append(node_id)
|
||||
|
||||
if not node_ids:
|
||||
if not unique_node_ids:
|
||||
logger.warning(f"批量更新激活值:没有有效的节点ID")
|
||||
return nodes
|
||||
|
||||
# 记录去重信息(仅针对具有有效 ID 的节点)
|
||||
id_nodes_count = sum(1 for n in nodes if n.get("id"))
|
||||
if len(unique_node_ids) < id_nodes_count:
|
||||
logger.info(
|
||||
f"批量更新激活值:检测到重复节点,具有有效ID的节点数量={id_nodes_count}, "
|
||||
f"去重后唯一ID数量={len(unique_node_ids)}"
|
||||
)
|
||||
|
||||
# 批量记录访问
|
||||
try:
|
||||
updated_nodes = await access_manager.record_batch_access(
|
||||
node_ids=node_ids,
|
||||
node_ids=unique_node_ids,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"批量更新激活值成功: {node_label}, "
|
||||
f"更新数量={len(updated_nodes)}/{len(node_ids)}"
|
||||
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
|
||||
)
|
||||
|
||||
return updated_nodes
|
||||
@@ -153,19 +167,38 @@ async def _update_search_results_activation(
|
||||
original_nodes = results[key]
|
||||
updated_nodes = update_result
|
||||
|
||||
# 创建 ID 到原始节点的映射(用于快速查找 score)
|
||||
original_map = {node.get('id'): node for node in original_nodes if node.get('id')}
|
||||
# 创建 ID 到更新节点的映射(用于快速查找激活值数据)
|
||||
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
|
||||
|
||||
# 合并数据:激活值来自更新结果,score 来自原始结果
|
||||
# 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充
|
||||
merged_nodes = []
|
||||
for updated_node in updated_nodes:
|
||||
node_id = updated_node.get('id')
|
||||
if node_id and node_id in original_map:
|
||||
# 保留原始的 score 字段
|
||||
original_score = original_map[node_id].get('score')
|
||||
if original_score is not None:
|
||||
updated_node['score'] = original_score
|
||||
merged_nodes.append(updated_node)
|
||||
for original_node in original_nodes:
|
||||
node_id = original_node.get('id')
|
||||
if node_id and node_id in updated_map:
|
||||
# 从原始节点开始,用更新后的激活值数据覆盖
|
||||
merged_node = original_node.copy()
|
||||
|
||||
# 更新激活值相关字段
|
||||
activation_fields = {
|
||||
'activation_value',
|
||||
'access_history',
|
||||
'last_access_time',
|
||||
'access_count',
|
||||
'importance_score',
|
||||
'version',
|
||||
'statement', # Statement 节点的内容字段
|
||||
'content' # MemorySummary 节点的内容字段
|
||||
}
|
||||
|
||||
# 只更新激活值相关字段,保留原始节点的其他字段
|
||||
for field in activation_fields:
|
||||
if field in updated_map[node_id]:
|
||||
merged_node[field] = updated_map[node_id][field]
|
||||
|
||||
merged_nodes.append(merged_node)
|
||||
else:
|
||||
# 如果没有更新数据,保留原始节点
|
||||
merged_nodes.append(original_node)
|
||||
|
||||
updated_results[key] = merged_nodes
|
||||
else:
|
||||
|
||||
@@ -15,6 +15,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -22,6 +23,7 @@ from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
@@ -101,6 +103,14 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
)
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.read_message",
|
||||
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
result = task_service.get_task_memory_read_result(task.id)
|
||||
status = result.get("status")
|
||||
logger.info(f"读取任务状态:{status}")
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
logger.info(f'用户ID:Agent:{end_user_id}')
|
||||
|
||||
@@ -491,6 +491,17 @@ class WorkflowService:
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
if exec_res.status == "completed":
|
||||
last_state = exec_res.output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
break
|
||||
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
@@ -504,7 +515,7 @@ class WorkflowService:
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result.get("node_outputs", {})
|
||||
output_data=result
|
||||
)
|
||||
else:
|
||||
self.update_execution_status(
|
||||
@@ -517,6 +528,7 @@ class WorkflowService:
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"variables": result.get("variables"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
@@ -617,6 +629,16 @@ class WorkflowService:
|
||||
original_user_id=payload.user_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
if exec_res.status == "completed":
|
||||
last_state = exec_res.output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
break
|
||||
|
||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||
async for event in self._run_workflow_stream(
|
||||
@@ -827,6 +849,23 @@ class WorkflowService:
|
||||
user_id=user_id
|
||||
):
|
||||
# 直接转发事件(executor 已经返回正确格式)
|
||||
if event.get("event") == "workflow_end":
|
||||
|
||||
status = event.get("data", {}).get("status")
|
||||
if status == "completed":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"completed",
|
||||
output_data=event.get("data")
|
||||
)
|
||||
elif status == "failed":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"failed",
|
||||
output_data=event.get("data")
|
||||
)
|
||||
else:
|
||||
logger.error(f"unexpect workflow run status, status: {status}")
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user