perf(workflow): add tests, adapt some LLM node output formats, optimize sandbox return format
This commit is contained in:
@@ -18,6 +18,8 @@ from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.code import CodeNode
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
@@ -35,5 +37,7 @@ __all__ = [
|
||||
"JinjaRenderNode",
|
||||
"ParameterExtractorNode",
|
||||
"QuestionClassifierNode",
|
||||
"ToolNode"
|
||||
"ToolNode",
|
||||
"CodeNode",
|
||||
"VariableAggregatorNode"
|
||||
]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.nodes.code.node import CodeNode
|
||||
|
||||
__all__ = ["CodeNode"]
|
||||
__all__ = ["CodeNode", "CodeNodeConfig"]
|
||||
|
||||
@@ -216,7 +216,7 @@ class LLMNode(BaseNode):
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||
return AIMessage(content=content, response_metadata=response.response_metadata)
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)"""
|
||||
|
||||
@@ -193,7 +193,8 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
self.response_metadata = model_resp.response_metadata
|
||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||
model_message = self.process_model_output(model_resp.content)
|
||||
result = json_repair.repair_json(model_message, return_objects=True)
|
||||
logger.info(f"node: {self.node_id} get params:{result}")
|
||||
|
||||
return result
|
||||
|
||||
@@ -131,7 +131,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
]
|
||||
|
||||
response = await llm.ainvoke(messages)
|
||||
result = response.content.strip()
|
||||
result = self.process_model_output(response.content)
|
||||
self.response_metadata = response.response_metadata
|
||||
|
||||
if result in category_names:
|
||||
|
||||
Reference in New Issue
Block a user