Merge #17 into develop from feature/20251219_myh
feat(workflow): add conditional branch (If-Else) node * feature/20251219_myh: (10 commits) fix(workflow): fix run_workflow streaming issues fix(prompt-optimizer): switch to built-in system prompt feat(workflow): add conditional branch (If-Else) node perf(types): add Union type declaration for workflow nodes fix(expression-eval): fix variable extraction issue in Jinja2 templates docs(samples): add config example for If-Else node style(workflow): update condition edge comments for conditional nodes style(enums): correct enum class name spelling refactor(workflow): unify all enum classes in one file and restructure workflow... feat(workflow): add import for if-else node configuration Signed-off-by: Eternity <1533512157@qq.com> Commented-by: Eternity <1533512157@qq.com> Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/17
This commit is contained in:
@@ -117,7 +117,7 @@ async def get_prompt_opt(
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
current_prompt=data.current_prompt,
|
current_prompt=data.current_prompt,
|
||||||
message=data.message
|
user_require=data.message
|
||||||
)
|
)
|
||||||
service.create_message(
|
service.create_message(
|
||||||
tenant_id=current_user.tenant_id,
|
tenant_id=current_user.tenant_id,
|
||||||
@@ -136,35 +136,3 @@ async def get_prompt_opt(
|
|||||||
return success(data=result_schema)
|
return success(data=result_schema)
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
|
||||||
"/model",
|
|
||||||
summary="Create or update prompt model config",
|
|
||||||
response_model=ApiResponse
|
|
||||||
)
|
|
||||||
def set_system_prompt(
|
|
||||||
data: PromptOptModelSet = ...,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create or update a system prompt model configuration for the tenant.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data (PromptOptModelSet): Model configuration data including model ID,
|
|
||||||
system prompt, and optional configuration ID
|
|
||||||
db (Session): Database session
|
|
||||||
current_user: Current user information
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UUID: The ID of the created or updated model configuration.
|
|
||||||
"""
|
|
||||||
if data.id is None:
|
|
||||||
data.id = uuid.uuid4()
|
|
||||||
|
|
||||||
model_config = PromptOptimizerService(db).create_update_model_config(
|
|
||||||
current_user.tenant_id,
|
|
||||||
data.id,
|
|
||||||
data.system_prompt
|
|
||||||
)
|
|
||||||
return success(data=model_config.id)
|
|
||||||
|
|
||||||
|
|||||||
@@ -473,7 +473,7 @@ async def run_workflow(
|
|||||||
async def event_generator():
|
async def event_generator():
|
||||||
"""生成 SSE 事件"""
|
"""生成 SSE 事件"""
|
||||||
try:
|
try:
|
||||||
async for event in service.run_workflow(
|
async for event in await service.run_workflow(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
triggered_by=current_user.id,
|
triggered_by=current_user.id,
|
||||||
|
|||||||
@@ -4,16 +4,17 @@
|
|||||||
基于 LangGraph 的工作流执行引擎。
|
基于 LangGraph 的工作流执行引擎。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||||
|
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -25,11 +26,11 @@ class WorkflowExecutor:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str
|
||||||
):
|
):
|
||||||
"""初始化执行器
|
"""初始化执行器
|
||||||
|
|
||||||
@@ -90,8 +91,6 @@ class WorkflowExecutor:
|
|||||||
"error_node": None
|
"error_node": None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def build_graph(self) -> CompiledStateGraph:
|
def build_graph(self) -> CompiledStateGraph:
|
||||||
"""构建 LangGraph
|
"""构建 LangGraph
|
||||||
|
|
||||||
@@ -112,19 +111,38 @@ class WorkflowExecutor:
|
|||||||
node_id = node.get("id")
|
node_id = node.get("id")
|
||||||
|
|
||||||
# 记录 start 和 end 节点 ID
|
# 记录 start 和 end 节点 ID
|
||||||
if node_type == "start":
|
if node_type == NodeType.START:
|
||||||
start_node_id = node_id
|
start_node_id = node_id
|
||||||
elif node_type == "end":
|
elif node_type == NodeType.END:
|
||||||
end_node_ids.append(node_id)
|
end_node_ids.append(node_id)
|
||||||
|
|
||||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||||
|
|
||||||
|
if node_type in [NodeType.IF_ELSE]:
|
||||||
|
expressions = node_instance.build_conditional_edge_expressions()
|
||||||
|
|
||||||
|
# Number of branches, usually matches the number of conditional expressions
|
||||||
|
branch_number = len(expressions)
|
||||||
|
|
||||||
|
# Find all edges whose source is the current node
|
||||||
|
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
||||||
|
|
||||||
|
# Iterate over each branch
|
||||||
|
for idx in range(branch_number):
|
||||||
|
# Generate a condition expression for each edge
|
||||||
|
# Used later to determine which branch to take based on the node's output
|
||||||
|
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||||
|
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
||||||
|
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||||
|
|
||||||
if node_instance:
|
if node_instance:
|
||||||
# 包装节点的 run 方法
|
# 包装节点的 run 方法
|
||||||
# 使用函数工厂避免闭包问题
|
# 使用函数工厂避免闭包问题
|
||||||
def make_node_func(inst):
|
def make_node_func(inst):
|
||||||
async def node_func(state: WorkflowState):
|
async def node_func(state: WorkflowState):
|
||||||
return await inst.run(state)
|
return await inst.run(state)
|
||||||
|
|
||||||
return node_func
|
return node_func
|
||||||
|
|
||||||
workflow.add_node(node_id, make_node_func(node_instance))
|
workflow.add_node(node_id, make_node_func(node_instance))
|
||||||
@@ -165,14 +183,14 @@ class WorkflowExecutor:
|
|||||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
def router(state: WorkflowState, cond=condition, tgt=target):
|
||||||
"""条件路由函数"""
|
"""条件路由函数"""
|
||||||
if evaluate_condition(
|
if evaluate_condition(
|
||||||
cond,
|
cond,
|
||||||
state.get("variables", {}),
|
state.get("variables", {}),
|
||||||
state.get("node_outputs", {}),
|
state.get("node_outputs", {}),
|
||||||
{
|
{
|
||||||
"execution_id": state.get("execution_id"),
|
"execution_id": state.get("execution_id"),
|
||||||
"workspace_id": state.get("workspace_id"),
|
"workspace_id": state.get("workspace_id"),
|
||||||
"user_id": state.get("user_id")
|
"user_id": state.get("user_id")
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
return tgt
|
return tgt
|
||||||
return END # 条件不满足,结束
|
return END # 条件不满足,结束
|
||||||
@@ -196,8 +214,8 @@ class WorkflowExecutor:
|
|||||||
return graph
|
return graph
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
input_data: dict[str, Any]
|
input_data: dict[str, Any]
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""执行工作流(非流式)
|
"""执行工作流(非流式)
|
||||||
|
|
||||||
@@ -271,8 +289,8 @@ class WorkflowExecutor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def execute_stream(
|
async def execute_stream(
|
||||||
self,
|
self,
|
||||||
input_data: dict[str, Any]
|
input_data: dict[str, Any]
|
||||||
):
|
):
|
||||||
"""执行工作流(流式)
|
"""执行工作流(流式)
|
||||||
|
|
||||||
@@ -305,7 +323,7 @@ class WorkflowExecutor:
|
|||||||
try:
|
try:
|
||||||
async for chunk in graph.astream(
|
async for chunk in graph.astream(
|
||||||
initial_state,
|
initial_state,
|
||||||
# subgraphs=True,
|
# subgraphs=True,
|
||||||
stream_mode="updates",
|
stream_mode="updates",
|
||||||
):
|
):
|
||||||
# print(chunk)
|
# print(chunk)
|
||||||
@@ -326,7 +344,6 @@ class WorkflowExecutor:
|
|||||||
"token_usage": None
|
"token_usage": None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
||||||
"""从节点输出中提取最终输出
|
"""从节点输出中提取最终输出
|
||||||
|
|
||||||
@@ -386,11 +403,11 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
|
|
||||||
async def execute_workflow(
|
async def execute_workflow(
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""执行工作流(便捷函数)
|
"""执行工作流(便捷函数)
|
||||||
|
|
||||||
@@ -414,11 +431,11 @@ async def execute_workflow(
|
|||||||
|
|
||||||
|
|
||||||
async def execute_workflow_stream(
|
async def execute_workflow_stream(
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str
|
||||||
):
|
):
|
||||||
"""执行工作流(流式,便捷函数)
|
"""执行工作流(流式,便捷函数)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
||||||
@@ -59,9 +60,10 @@ class ExpressionEvaluator:
|
|||||||
"""
|
"""
|
||||||
# 移除 Jinja2 模板语法的花括号(如果存在)
|
# 移除 Jinja2 模板语法的花括号(如果存在)
|
||||||
expression = expression.strip()
|
expression = expression.strip()
|
||||||
if expression.startswith("{{") and expression.endswith("}}"):
|
# "{{system.message}} == {{ user.messge }}" -> "system.message == user.message"
|
||||||
expression = expression[2:-2].strip()
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
|
expression = re.sub(pattern, r"\1", expression).strip()
|
||||||
|
|
||||||
# 构建命名空间上下文
|
# 构建命名空间上下文
|
||||||
context = {
|
context = {
|
||||||
"var": variables, # 用户变量
|
"var": variables, # 用户变量
|
||||||
|
|||||||
@@ -4,13 +4,14 @@
|
|||||||
提供各种类型的节点实现,用于工作流执行。
|
提供各种类型的节点实现,用于工作流执行。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
|
||||||
from app.core.workflow.nodes.llm import LLMNode
|
|
||||||
from app.core.workflow.nodes.agent import AgentNode
|
from app.core.workflow.nodes.agent import AgentNode
|
||||||
from app.core.workflow.nodes.transform import TransformNode
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.workflow.nodes.start import StartNode
|
|
||||||
from app.core.workflow.nodes.end import EndNode
|
from app.core.workflow.nodes.end import EndNode
|
||||||
from app.core.workflow.nodes.node_factory import NodeFactory
|
from app.core.workflow.nodes.if_else import IfElseNode
|
||||||
|
from app.core.workflow.nodes.llm import LLMNode
|
||||||
|
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||||
|
from app.core.workflow.nodes.start import StartNode
|
||||||
|
from app.core.workflow.nodes.transform import TransformNode
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseNode",
|
"BaseNode",
|
||||||
@@ -18,7 +19,9 @@ __all__ = [
|
|||||||
"LLMNode",
|
"LLMNode",
|
||||||
"AgentNode",
|
"AgentNode",
|
||||||
"TransformNode",
|
"TransformNode",
|
||||||
|
"IfElseNode",
|
||||||
"StartNode",
|
"StartNode",
|
||||||
"EndNode",
|
"EndNode",
|
||||||
"NodeFactory",
|
"NodeFactory",
|
||||||
|
"WorkflowNode"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from app.core.workflow.nodes.end.config import EndNodeConfig
|
|||||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||||
|
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 基础类
|
# 基础类
|
||||||
@@ -26,4 +27,5 @@ __all__ = [
|
|||||||
"MessageConfig",
|
"MessageConfig",
|
||||||
"AgentNodeConfig",
|
"AgentNodeConfig",
|
||||||
"TransformNodeConfig",
|
"TransformNodeConfig",
|
||||||
|
"IfElseNodeConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
class NodeType(StrEnum):
|
class NodeType(StrEnum):
|
||||||
START = "start"
|
START = "start"
|
||||||
END = "end"
|
END = "end"
|
||||||
@@ -13,3 +14,23 @@ class NodeType(StrEnum):
|
|||||||
HTTP_REQUEST = "http-request"
|
HTTP_REQUEST = "http-request"
|
||||||
TOOL = "tool"
|
TOOL = "tool"
|
||||||
AGENT = "agent"
|
AGENT = "agent"
|
||||||
|
|
||||||
|
|
||||||
|
class ComparisonOperator(StrEnum):
|
||||||
|
EMPTY = "empty"
|
||||||
|
NOT_EMPTY = "not_empty"
|
||||||
|
CONTAINS = "contains"
|
||||||
|
NOT_CONTAINS = "not_contains"
|
||||||
|
START_WITH = "startwith"
|
||||||
|
END_WITH = "endwith"
|
||||||
|
EQ = "eq"
|
||||||
|
NE = "ne"
|
||||||
|
LT = "lt"
|
||||||
|
LE = "le"
|
||||||
|
GT = "gt"
|
||||||
|
GE = "ge"
|
||||||
|
|
||||||
|
|
||||||
|
class LogicOperator(StrEnum):
|
||||||
|
AND = "and"
|
||||||
|
OR = "or"
|
||||||
|
|||||||
5
api/app/core/workflow/nodes/if_else/__init__.py
Normal file
5
api/app/core/workflow/nodes/if_else/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Condition Node"""
|
||||||
|
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||||
|
from app.core.workflow.nodes.if_else.node import IfElseNode
|
||||||
|
|
||||||
|
__all__ = ["IfElseNode", "IfElseNodeConfig"]
|
||||||
97
api/app/core/workflow/nodes/if_else/config.py
Normal file
97
api/app/core/workflow/nodes/if_else/config.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""Condition Configuration"""
|
||||||
|
from pydantic import Field, BaseModel, field_validator
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionDetail(BaseModel):
|
||||||
|
comparison_operator: ComparisonOperator = Field(
|
||||||
|
...,
|
||||||
|
description="Comparison operator used to evaluate the condition"
|
||||||
|
)
|
||||||
|
|
||||||
|
left: str = Field(
|
||||||
|
...,
|
||||||
|
description="Value to compare against"
|
||||||
|
)
|
||||||
|
|
||||||
|
right: str = Field(
|
||||||
|
...,
|
||||||
|
description="Value to compare with"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionBranchConfig(BaseModel):
|
||||||
|
"""Configuration for a conditional branch"""
|
||||||
|
|
||||||
|
logical_operator: LogicOperator = Field(
|
||||||
|
default=LogicOperator.AND.value,
|
||||||
|
description="Logical operator used to combine multiple condition expressions"
|
||||||
|
)
|
||||||
|
|
||||||
|
conditions: list[ConditionDetail] = Field(
|
||||||
|
...,
|
||||||
|
description="List of condition expressions within this branch"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IfElseNodeConfig(BaseNodeConfig):
|
||||||
|
cases: list[ConditionBranchConfig] = Field(
|
||||||
|
...,
|
||||||
|
description="List of branch conditions or expressions"
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("cases")
|
||||||
|
@classmethod
|
||||||
|
def validate_case_number(cls, v, info):
|
||||||
|
if len(v) < 1:
|
||||||
|
raise ValueError("At least one cases are required")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"examples": [
|
||||||
|
{
|
||||||
|
"cases": [
|
||||||
|
# CASE1 / IF Branch
|
||||||
|
{
|
||||||
|
"logical_operator": "and",
|
||||||
|
"conditions": [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"left": "node.userinput.message",
|
||||||
|
"comparison_operator": "eq",
|
||||||
|
"right": "'123'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"left": "node.userinput.test",
|
||||||
|
"comparison_operator": "eq",
|
||||||
|
"right": "True"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# CASE1 / ELIF Branch
|
||||||
|
{
|
||||||
|
"logical_operator": "or",
|
||||||
|
"conditions": [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"left": "node.userinput.test",
|
||||||
|
"comparison_operator": "eq",
|
||||||
|
"right": "False"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"left": "node.userinput.message",
|
||||||
|
"comparison_operator": "contains",
|
||||||
|
"right": "'123'"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
# CASE3 / ELSE Branch
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
167
api/app/core/workflow/nodes/if_else/node.py
Normal file
167
api/app/core/workflow/nodes/if_else/node.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||||
|
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||||
|
from app.core.workflow.nodes.if_else.config import ConditionDetail
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionExpressionBuilder:
|
||||||
|
"""
|
||||||
|
Build a Python boolean expression string based on a comparison operator.
|
||||||
|
|
||||||
|
This class does not evaluate the expression.
|
||||||
|
It only generates a valid Python expression string
|
||||||
|
that can be evaluated later in a workflow context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, left: str, operator: ComparisonOperator, right: str):
|
||||||
|
self.left = left
|
||||||
|
self.operator = operator
|
||||||
|
self.right = right
|
||||||
|
|
||||||
|
def _empty(self):
|
||||||
|
return f"{self.left} == ''"
|
||||||
|
|
||||||
|
def _not_empty(self):
|
||||||
|
return f"{self.left} != ''"
|
||||||
|
|
||||||
|
def _contains(self):
|
||||||
|
return f"{self.right} in {self.left}"
|
||||||
|
|
||||||
|
def _not_contains(self):
|
||||||
|
return f"{self.right} not in {self.left}"
|
||||||
|
|
||||||
|
def _startwith(self):
|
||||||
|
return f'{self.left}.startswith({self.right})'
|
||||||
|
|
||||||
|
def _endwith(self):
|
||||||
|
return f'{self.left}.endswith({self.right})'
|
||||||
|
|
||||||
|
def _eq(self):
|
||||||
|
return f"{self.left} == {self.right}"
|
||||||
|
|
||||||
|
def _ne(self):
|
||||||
|
return f"{self.left} != {self.right}"
|
||||||
|
|
||||||
|
def _lt(self):
|
||||||
|
return f"{self.left} < {self.right}"
|
||||||
|
|
||||||
|
def _le(self):
|
||||||
|
return f"{self.left} <= {self.right}"
|
||||||
|
|
||||||
|
def _gt(self):
|
||||||
|
return f"{self.left} > {self.right}"
|
||||||
|
|
||||||
|
def _ge(self):
|
||||||
|
return f"{self.left} >= {self.right}"
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
match self.operator:
|
||||||
|
case ComparisonOperator.EMPTY:
|
||||||
|
return self._empty()
|
||||||
|
case ComparisonOperator.NOT_EMPTY:
|
||||||
|
return self._not_empty()
|
||||||
|
case ComparisonOperator.CONTAINS:
|
||||||
|
return self._contains()
|
||||||
|
case ComparisonOperator.NOT_CONTAINS:
|
||||||
|
return self._not_contains()
|
||||||
|
case ComparisonOperator.START_WITH:
|
||||||
|
return self._startwith()
|
||||||
|
case ComparisonOperator.END_WITH:
|
||||||
|
return self._endwith()
|
||||||
|
case ComparisonOperator.EQ:
|
||||||
|
return self._eq()
|
||||||
|
case ComparisonOperator.NE:
|
||||||
|
return self._ne()
|
||||||
|
case ComparisonOperator.LT:
|
||||||
|
return self._lt()
|
||||||
|
case ComparisonOperator.LE:
|
||||||
|
return self._le()
|
||||||
|
case ComparisonOperator.GT:
|
||||||
|
return self._gt()
|
||||||
|
case ComparisonOperator.GE:
|
||||||
|
return self._ge()
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid condition: {self.operator}")
|
||||||
|
|
||||||
|
|
||||||
|
class IfElseNode(BaseNode):
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = IfElseNodeConfig(**self.config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_condition_expression(
|
||||||
|
condition: ConditionDetail,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Build a single boolean condition expression string.
|
||||||
|
|
||||||
|
This method does NOT evaluate the condition.
|
||||||
|
It only generates a valid Python boolean expression string
|
||||||
|
(e.g. "x > 10", "'a' in name") that can later be used
|
||||||
|
in a conditional edge or evaluated by the workflow engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
condition (ConditionDetail): Definition of a single comparison condition.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A Python boolean expression string.
|
||||||
|
"""
|
||||||
|
return ConditionExpressionBuilder(
|
||||||
|
left=condition.left,
|
||||||
|
operator=condition.comparison_operator,
|
||||||
|
right=condition.right
|
||||||
|
).build()
|
||||||
|
|
||||||
|
def build_conditional_edge_expressions(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Build conditional edge expressions for the If-Else node.
|
||||||
|
|
||||||
|
This method does NOT evaluate any condition at runtime.
|
||||||
|
Instead, it converts each case branch into a Python boolean
|
||||||
|
expression string, which will later be attached to LangGraph
|
||||||
|
as conditional edges.
|
||||||
|
|
||||||
|
Each returned expression corresponds to one branch and is
|
||||||
|
evaluated in order. A fallback 'True' condition is appended
|
||||||
|
to ensure a default branch when no previous conditions match.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: A list of Python boolean expression strings,
|
||||||
|
ordered by branch priority.
|
||||||
|
"""
|
||||||
|
branch_index = 0
|
||||||
|
conditions = []
|
||||||
|
|
||||||
|
for case_branch in self.typed_config.cases:
|
||||||
|
branch_index += 1
|
||||||
|
|
||||||
|
branch_conditions = [
|
||||||
|
self._build_condition_expression(condition)
|
||||||
|
for condition in case_branch.conditions
|
||||||
|
]
|
||||||
|
if len(branch_conditions) > 1:
|
||||||
|
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
|
||||||
|
else:
|
||||||
|
combined_condition = branch_conditions[0]
|
||||||
|
conditions.append(combined_condition)
|
||||||
|
|
||||||
|
# Default fallback branch
|
||||||
|
conditions.append("True")
|
||||||
|
|
||||||
|
return conditions
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
expressions = self.build_conditional_edge_expressions()
|
||||||
|
for i in range(len(expressions)):
|
||||||
|
logger.info(expressions[i])
|
||||||
|
if self._evaluate_condition(expressions[i], state):
|
||||||
|
return f'CASE{i+1}'
|
||||||
|
return f'CASE{len(expressions)}'
|
||||||
@@ -5,18 +5,29 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, Union
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
|
||||||
from app.core.workflow.nodes.llm import LLMNode
|
|
||||||
from app.core.workflow.nodes.agent import AgentNode
|
from app.core.workflow.nodes.agent import AgentNode
|
||||||
from app.core.workflow.nodes.transform import TransformNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.start import StartNode
|
|
||||||
from app.core.workflow.nodes.end import EndNode
|
from app.core.workflow.nodes.end import EndNode
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
from app.core.workflow.nodes.if_else import IfElseNode
|
||||||
|
from app.core.workflow.nodes.llm import LLMNode
|
||||||
|
from app.core.workflow.nodes.start import StartNode
|
||||||
|
from app.core.workflow.nodes.transform import TransformNode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
WorkflowNode = Union[
|
||||||
|
BaseNode,
|
||||||
|
StartNode,
|
||||||
|
EndNode,
|
||||||
|
LLMNode,
|
||||||
|
IfElseNode,
|
||||||
|
AgentNode,
|
||||||
|
TransformNode,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class NodeFactory:
|
class NodeFactory:
|
||||||
"""节点工厂
|
"""节点工厂
|
||||||
@@ -25,16 +36,17 @@ class NodeFactory:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 节点类型注册表
|
# 节点类型注册表
|
||||||
_node_types: dict[str, type[BaseNode]] = {
|
_node_types: dict[str, type[WorkflowNode]] = {
|
||||||
NodeType.START: StartNode,
|
NodeType.START: StartNode,
|
||||||
NodeType.END: EndNode,
|
NodeType.END: EndNode,
|
||||||
NodeType.LLM: LLMNode,
|
NodeType.LLM: LLMNode,
|
||||||
NodeType.AGENT: AgentNode,
|
NodeType.AGENT: AgentNode,
|
||||||
NodeType.TRANSFORM: TransformNode,
|
NodeType.TRANSFORM: TransformNode,
|
||||||
|
NodeType.IF_ELSE: IfElseNode
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
|
def register_node_type(cls, node_type: str, node_class: type[WorkflowNode]):
|
||||||
"""注册新的节点类型
|
"""注册新的节点类型
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -52,10 +64,10 @@ class NodeFactory:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_node(
|
def create_node(
|
||||||
cls,
|
cls,
|
||||||
node_config: dict[str, Any],
|
node_config: dict[str, Any],
|
||||||
workflow_config: dict[str, Any]
|
workflow_config: dict[str, Any]
|
||||||
) -> BaseNode | None:
|
) -> WorkflowNode | None:
|
||||||
"""创建节点实例
|
"""创建节点实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from .data_config_model import DataConfig
|
|||||||
from .multi_agent_model import MultiAgentConfig, AgentInvocation
|
from .multi_agent_model import MultiAgentConfig, AgentInvocation
|
||||||
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
|
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
|
||||||
from .retrieval_info import RetrievalInfo
|
from .retrieval_info import RetrievalInfo
|
||||||
from .prompt_optimizer_model import PromptOptimizerModelConfig, PromptOptimizerSession, PromptOptimizerSessionHistory
|
from .prompt_optimizer_model import PromptOptimizerSession, PromptOptimizerSessionHistory
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Tenants",
|
"Tenants",
|
||||||
@@ -56,7 +56,6 @@ __all__ = [
|
|||||||
"WorkflowExecution",
|
"WorkflowExecution",
|
||||||
"WorkflowNodeExecution",
|
"WorkflowNodeExecution",
|
||||||
"RetrievalInfo",
|
"RetrievalInfo",
|
||||||
"PromptOptimizerModelConfig",
|
|
||||||
"PromptOptimizerSession",
|
"PromptOptimizerSession",
|
||||||
"PromptOptimizerSessionHistory"
|
"PromptOptimizerSessionHistory"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -27,49 +27,6 @@ class RoleType(StrEnum):
|
|||||||
ASSISTANT = "assistant"
|
ASSISTANT = "assistant"
|
||||||
|
|
||||||
|
|
||||||
class PromptOptimizerModelConfig(Base):
|
|
||||||
"""
|
|
||||||
Prompt Optimization Model Configuration.
|
|
||||||
|
|
||||||
This table stores system-level prompt configurations for each tenant.
|
|
||||||
The configuration defines the base system prompt used during prompt
|
|
||||||
optimization sessions and serves as a foundational instruction set
|
|
||||||
for the optimization process.
|
|
||||||
|
|
||||||
Each tenant may have one or more model configurations depending on
|
|
||||||
business requirements.
|
|
||||||
|
|
||||||
Table Name:
|
|
||||||
prompt_model_config
|
|
||||||
|
|
||||||
Columns:
|
|
||||||
id (UUID):
|
|
||||||
Primary key. Unique identifier for the prompt model configuration.
|
|
||||||
tenant_id (UUID):
|
|
||||||
Foreign key referencing `tenants.id`.
|
|
||||||
Identifies the tenant that owns this configuration.
|
|
||||||
system_prompt (Text):
|
|
||||||
The system-level prompt used to guide prompt optimization logic.
|
|
||||||
created_at (DateTime):
|
|
||||||
Timestamp indicating when the configuration was created.
|
|
||||||
updated_at (DateTime):
|
|
||||||
Timestamp indicating the last update time of the configuration.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
- Loaded when initializing a prompt optimization session
|
|
||||||
- Acts as the root system instruction for all subsequent prompts
|
|
||||||
"""
|
|
||||||
__tablename__ = "prompt_model_config"
|
|
||||||
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
|
||||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
|
|
||||||
# model_id = Column(UUID(as_uuid=True), nullable=False, comment="Model ID")
|
|
||||||
system_prompt = Column(Text, nullable=False, comment="System Prompt")
|
|
||||||
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time")
|
|
||||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="Update Time")
|
|
||||||
|
|
||||||
|
|
||||||
class PromptOptimizerSession(Base):
|
class PromptOptimizerSession(Base):
|
||||||
"""
|
"""
|
||||||
Prompt Optimization Session Registry.
|
Prompt Optimization Session Registry.
|
||||||
|
|||||||
@@ -1,120 +1,15 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_db_logger
|
from app.core.logging_config import get_db_logger
|
||||||
from app.models.prompt_optimizer_model import (
|
from app.models.prompt_optimizer_model import (
|
||||||
PromptOptimizerModelConfig,
|
|
||||||
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
|
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
|
||||||
)
|
)
|
||||||
|
|
||||||
db_logger = get_db_logger()
|
db_logger = get_db_logger()
|
||||||
|
|
||||||
|
|
||||||
class PromptOptimizerModelConfigRepository:
|
|
||||||
"""Repository for managing prompt optimizer model configurations."""
|
|
||||||
|
|
||||||
def __init__(self, db: Session):
|
|
||||||
self.db = db
|
|
||||||
|
|
||||||
def get_by_tenant_id(self, tenant_id: uuid.UUID) -> Optional[PromptOptimizerModelConfig]:
|
|
||||||
"""
|
|
||||||
Retrieve the prompt optimizer model configuration for a specific tenant.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id (uuid.UUID): The unique identifier of the tenant.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[PromptOptimizerModelConfig]: The model configuration if found, else None.
|
|
||||||
"""
|
|
||||||
db_logger.debug(f"Get prompt optimization model configuration: tenant_id={tenant_id}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
config = self.db.query(PromptOptimizerModelConfig).filter(
|
|
||||||
PromptOptimizerModelConfig.tenant_id == tenant_id,
|
|
||||||
# PromptOptimizerModelConfig.model_id == model_id
|
|
||||||
).first()
|
|
||||||
if config:
|
|
||||||
db_logger.debug(f"Prompt optimization model configuration found: (ID: {config.id})")
|
|
||||||
else:
|
|
||||||
db_logger.debug(f"Prompt optimization model configuration not found: tenant_id={tenant_id}")
|
|
||||||
return config
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(
|
|
||||||
f"Error retrieving prompt optimization model configuration: tenant_id={tenant_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_by_config_id(self, tenant_id: uuid.UUID, config_id: uuid.UUID) -> Optional[PromptOptimizerModelConfig]:
|
|
||||||
"""
|
|
||||||
Retrieve a specific prompt optimizer model configuration by config ID and tenant ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id (uuid.UUID): The unique identifier of the tenant.
|
|
||||||
config_id (uuid.UUID): The unique identifier of the model configuration.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[PromptOptimizerModelConfig]: The model configuration if found, else None.
|
|
||||||
"""
|
|
||||||
db_logger.debug(f"Get prompt optimization model configuration: config_id={config_id}, tenant_id={tenant_id}")
|
|
||||||
try:
|
|
||||||
model = self.db.query(PromptOptimizerModelConfig).filter(
|
|
||||||
PromptOptimizerModelConfig.tenant_id == tenant_id,
|
|
||||||
PromptOptimizerModelConfig.id == config_id
|
|
||||||
).first()
|
|
||||||
if model:
|
|
||||||
db_logger.debug(f"Prompt optimization model configuration found: (ID: {model.id})")
|
|
||||||
else:
|
|
||||||
db_logger.debug(f"Prompt optimization model configuration not found: config_id={config_id}")
|
|
||||||
return model
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(
|
|
||||||
f"Error retrieving prompt optimization model configuration: model_id={config_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def create_or_update(
|
|
||||||
self,
|
|
||||||
config_id: uuid.UUID,
|
|
||||||
tenant_id: uuid.UUID,
|
|
||||||
system_prompt: str,
|
|
||||||
) -> Optional[PromptOptimizerModelConfig]:
|
|
||||||
"""
|
|
||||||
Create a new or update an existing prompt optimizer model configuration.
|
|
||||||
|
|
||||||
If a configuration with the given config_id exists, it updates its system_prompt.
|
|
||||||
Otherwise, it creates a new configuration record.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_id (uuid.UUID): The unique identifier for the configuration.
|
|
||||||
tenant_id (uuid.UUID): The tenant's unique identifier.
|
|
||||||
system_prompt (str): The system prompt content for prompt optimization.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[PromptOptimizerModelConfig]: The created or updated model configuration.
|
|
||||||
"""
|
|
||||||
db_logger.debug(f"Create/Update prompt optimization model configuration: tenant_id={tenant_id}")
|
|
||||||
existing_config = self.get_by_config_id(tenant_id, config_id)
|
|
||||||
|
|
||||||
if existing_config:
|
|
||||||
existing_config.system_prompt = system_prompt
|
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(existing_config)
|
|
||||||
db_logger.debug(f"Prompt optimization model configuration update: ID:{config_id}")
|
|
||||||
return existing_config
|
|
||||||
else:
|
|
||||||
config = PromptOptimizerModelConfig(
|
|
||||||
id=config_id,
|
|
||||||
# model_id=model_id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
system_prompt=system_prompt
|
|
||||||
)
|
|
||||||
self.db.add(config)
|
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(config)
|
|
||||||
db_logger.debug(f"Prompt optimization model configuration created: ID:{config.id}")
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
class PromptOptimizerSessionRepository:
|
class PromptOptimizerSessionRepository:
|
||||||
"""Repository for managing prompt optimization sessions and session history."""
|
"""Repository for managing prompt optimization sessions and session history."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
@@ -12,13 +11,11 @@ from app.core.models import RedBearModelConfig
|
|||||||
from app.core.models.llm import RedBearLLM
|
from app.core.models.llm import RedBearLLM
|
||||||
from app.models import ModelConfig, ModelApiKey, ModelType, PromptOptimizerSessionHistory
|
from app.models import ModelConfig, ModelApiKey, ModelType, PromptOptimizerSessionHistory
|
||||||
from app.models.prompt_optimizer_model import (
|
from app.models.prompt_optimizer_model import (
|
||||||
PromptOptimizerModelConfig,
|
|
||||||
PromptOptimizerSession,
|
PromptOptimizerSession,
|
||||||
RoleType
|
RoleType
|
||||||
)
|
)
|
||||||
from app.repositories.model_repository import ModelConfigRepository
|
from app.repositories.model_repository import ModelConfigRepository
|
||||||
from app.repositories.prompt_optimizer_repository import (
|
from app.repositories.prompt_optimizer_repository import (
|
||||||
PromptOptimizerModelConfigRepository,
|
|
||||||
PromptOptimizerSessionRepository
|
PromptOptimizerSessionRepository
|
||||||
)
|
)
|
||||||
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
|
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
|
||||||
@@ -34,32 +31,24 @@ class PromptOptimizerService:
|
|||||||
self,
|
self,
|
||||||
tenant_id: uuid.UUID,
|
tenant_id: uuid.UUID,
|
||||||
model_id: uuid.UUID
|
model_id: uuid.UUID
|
||||||
) -> tuple[PromptOptimizerModelConfig, ModelConfig]:
|
) -> ModelConfig:
|
||||||
"""
|
"""
|
||||||
Retrieve the prompt optimizer model configuration and model configuration.
|
Retrieve the model configuration for a specific tenant.
|
||||||
|
|
||||||
This method retrieves the prompt optimizer model configuration associated
|
This method fetches the model configuration associated with the given
|
||||||
with the specified model ID and tenant. It also fetches the corresponding
|
tenant_id and model_id. If no configuration is found, a BusinessException
|
||||||
model configuration.
|
is raised.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tenant_id (uuid.UUID): The unique identifier of the tenant.
|
tenant_id (uuid.UUID): The unique identifier of the tenant.
|
||||||
model_id (uuid.UUID): The unique identifier of the prompt optimization model.
|
model_id (uuid.UUID): The unique identifier of the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[PromptOptimzerModelConfig, ModelConfig]:
|
ModelConfig: The corresponding model configuration object.
|
||||||
A tuple containing the prompt optimizer model configuration
|
|
||||||
and the corresponding model configuration.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
BusinessException: If the prompt optimizer model configuration does not exist.
|
|
||||||
BusinessException: If the model configuration does not exist.
|
BusinessException: If the model configuration does not exist.
|
||||||
"""
|
"""
|
||||||
prompt_config = PromptOptimizerModelConfigRepository(self.db).get_by_tenant_id(
|
|
||||||
tenant_id
|
|
||||||
)
|
|
||||||
if not prompt_config:
|
|
||||||
raise BusinessException("提示词模型配置不存在", BizCode.NOT_FOUND)
|
|
||||||
|
|
||||||
model = ModelConfigRepository.get_by_id(
|
model = ModelConfigRepository.get_by_id(
|
||||||
self.db, model_id, tenant_id=tenant_id
|
self.db, model_id, tenant_id=tenant_id
|
||||||
@@ -67,35 +56,7 @@ class PromptOptimizerService:
|
|||||||
if not model:
|
if not model:
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
return prompt_config, model
|
return model
|
||||||
|
|
||||||
def create_update_model_config(
|
|
||||||
self,
|
|
||||||
tenant_id: uuid.UUID,
|
|
||||||
config_id: uuid.UUID,
|
|
||||||
system_prompt: str,
|
|
||||||
) -> PromptOptimizerModelConfig:
|
|
||||||
"""
|
|
||||||
Create or update a prompt optimizer model configuration.
|
|
||||||
|
|
||||||
This method creates a new prompt optimizer model configuration or updates
|
|
||||||
an existing one identified by the given configuration ID. The configuration
|
|
||||||
defines the system prompt used for prompt optimization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id (uuid.UUID): The unique identifier of the tenant.
|
|
||||||
config_id (uuid.UUID): The unique identifier of the configuration to create or update.
|
|
||||||
system_prompt (str): The system prompt content used for prompt optimization.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PromptOptimzerModelConfig: The created or updated prompt optimizer model configuration.
|
|
||||||
"""
|
|
||||||
prompt_config = PromptOptimizerModelConfigRepository(self.db).create_or_update(
|
|
||||||
config_id=config_id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
)
|
|
||||||
return prompt_config
|
|
||||||
|
|
||||||
def create_session(
|
def create_session(
|
||||||
self,
|
self,
|
||||||
@@ -159,37 +120,46 @@ class PromptOptimizerService:
|
|||||||
session_id: uuid.UUID,
|
session_id: uuid.UUID,
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
current_prompt: str,
|
current_prompt: str,
|
||||||
message: str
|
user_require: str
|
||||||
) -> OptimizePromptResult:
|
) -> OptimizePromptResult:
|
||||||
"""
|
"""
|
||||||
Optimize a prompt using a prompt optimizer LLM.
|
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
||||||
|
|
||||||
This method uses a configured prompt optimizer model to refine an existing
|
This method refines the original prompt according to the user's requirements,
|
||||||
prompt based on the user's requirements. The optimized prompt is generated
|
generating an optimized version that is directly usable by AI tools. The
|
||||||
according to predefined system rules, including Jinja2 variable syntax and
|
optimization process follows strict rules, including:
|
||||||
a strict JSON output format.
|
- Wrapping user-inserted variables in double curly braces {{}}.
|
||||||
|
- Adhering to Jinja2 variable syntax if applicable.
|
||||||
|
- Ensuring a clear logic flow, explicit instructions, and strong executability.
|
||||||
|
- Producing output in a strict JSON format.
|
||||||
|
|
||||||
|
Steps performed:
|
||||||
|
1. Retrieve the model configuration for the given tenant and model.
|
||||||
|
2. Fetch the session message history for context.
|
||||||
|
3. Instantiate the LLM with the appropriate API key and model configuration.
|
||||||
|
4. Build system messages outlining optimization rules.
|
||||||
|
5. Format the user's original prompt and requirements as a user message.
|
||||||
|
6. Send messages to the LLM to generate the optimized prompt.
|
||||||
|
7. Generate a concise description summarizing the changes made during optimization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tenant_id (uuid.UUID): The unique identifier of the tenant.
|
tenant_id (uuid.UUID): Tenant identifier.
|
||||||
model_id (uuid.UUID): The unique identifier of the prompt optimizer model.
|
model_id (uuid.UUID): Prompt optimizer model identifier.
|
||||||
session_id (uuid.UUID): The unique identifier of the prompt optimization session.
|
session_id (uuid.UUID): Prompt optimization session identifier.
|
||||||
user_id (uuid.UUID): The unique identifier of the user associated with the session.
|
user_id (uuid.UUID): Identifier of the user associated with the session.
|
||||||
current_prompt (str): The original prompt to be optimized.
|
current_prompt (str): Original prompt to optimize.
|
||||||
message (str): The user's requirements or modification instructions.
|
user_require (str): User's requirements or instructions for optimization.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing the optimized prompt and the description
|
OptimizePromptResult: An object containing:
|
||||||
of changes, in the following format:
|
- prompt: The optimized prompt string.
|
||||||
{
|
- desc: A short description summarizing the changes.
|
||||||
"prompt": "<optimized_prompt>",
|
|
||||||
"desc": "<change_description>"
|
|
||||||
}
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
BusinessException: If the model response cannot be parsed as valid JSON
|
BusinessException: If the LLM response cannot be parsed as valid JSON
|
||||||
or does not conform to the expected output format.
|
or does not conform to the expected output format.
|
||||||
"""
|
"""
|
||||||
prompt_config, model_config = self.get_model_config(tenant_id, model_id)
|
model_config = self.get_model_config(tenant_id, model_id)
|
||||||
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
||||||
|
|
||||||
# Create LLM instance
|
# Create LLM instance
|
||||||
@@ -204,36 +174,65 @@ class PromptOptimizerService:
|
|||||||
# build message
|
# build message
|
||||||
messages = [
|
messages = [
|
||||||
# init system_prompt
|
# init system_prompt
|
||||||
(RoleType.SYSTEM.value, prompt_config.system_prompt),
|
(
|
||||||
|
RoleType.SYSTEM.value,
|
||||||
|
"Your task is to optimize the original prompt provided by the user so that it can be directly used by AI tools,"
|
||||||
|
"and the variables that the user needs to insert must be wrapped in {{}}. "
|
||||||
|
"The optimized prompt should align with the optimization direction specified by the user (if any) and ensure clear logic, explicit instructions, and strong executability. "
|
||||||
|
"Please follow these rules when optimizing: "
|
||||||
|
'1. Ensure variables are wrapped in {{}}, e.g., optimize "Please enter your question" to "Please enter your {{question}}"'
|
||||||
|
"2. Instructions must be specific and operable, avoiding vague expressions"
|
||||||
|
"3. If the original prompt lacks key elements (such as output format requirements), supplement them completely "
|
||||||
|
"4. Keep the language concise and avoid redundancy "
|
||||||
|
"5. If the user does not specify an optimization direction, the default optimization is to make the prompt structurally clear and with explicit instructions"
|
||||||
|
"Please directly output the optimized prompt without additional explanations. The optimized prompt should be directly usable with correct variable positions."
|
||||||
|
),
|
||||||
|
|
||||||
# base model limit
|
# base model limit
|
||||||
(RoleType.SYSTEM.value,
|
(RoleType.SYSTEM.value,
|
||||||
"Optimization Rules:\n"
|
"Optimization Rules:\n"
|
||||||
"1. Fully adjust the prompt content according to the user's requirements.\n"
|
"1. Fully adjust the prompt content according to the user's requirements.\n"
|
||||||
"2. When the user requests the insertion of variables, you must use Jinja2 syntax {{variable_name}} "
|
"When variables are required, use double curly braces {{variable_name}} as placeholders."
|
||||||
"(the variable name should be determined based on the user's requirement).\n"
|
"Variable names must be derived from the user's requirements.\n"
|
||||||
"3. Keep the prompt logic clear and instructions explicit.\n"
|
"3. Keep the prompt logic clear and instructions explicit.\n"
|
||||||
"4. Ensure that the modified prompt can be directly used.\n\n"
|
"4. Ensure that the modified prompt can be directly used.\n\n")
|
||||||
"Output Requirements:\n"
|
|
||||||
"Provide the result in JSON format, containing exactly two fields:\n"
|
|
||||||
" - prompt: The modified prompt (string).\n"
|
|
||||||
" - desc: A response addressing the user's optimization request (string).")
|
|
||||||
]
|
]
|
||||||
messages.extend(session_history[:-1]) # last message is current message
|
messages.extend(session_history[:-1]) # last message is current message
|
||||||
user_message_template = ChatPromptTemplate.from_messages([
|
user_message_template = ChatPromptTemplate.from_messages([
|
||||||
(RoleType.USER.value, "[current_prompt]\n{current_prompt}\n[user_require]\n{message}")
|
(RoleType.USER.value, "[original_prompt]\n{current_prompt}\n[user_require]\n{user_require}")
|
||||||
])
|
])
|
||||||
formatted_user_message = user_message_template.format(current_prompt=current_prompt, message=message)
|
formatted_user_message = user_message_template.format(current_prompt=current_prompt, user_require=user_require)
|
||||||
messages.extend([(RoleType.USER.value, formatted_user_message)])
|
messages.extend([(RoleType.USER.value, formatted_user_message)])
|
||||||
logger.info(f"Prompt optimization message: {messages}")
|
logger.info(f"Prompt optimization message: {messages}")
|
||||||
result = await llm.ainvoke(messages)
|
optim_prompt = await llm.ainvoke(messages)
|
||||||
try:
|
optim_desc = [
|
||||||
data_dict = json.loads(result.content)
|
(
|
||||||
model_resp = OptimizePromptResult.model_validate(data_dict)
|
RoleType.SYSTEM.value,
|
||||||
except Exception as e:
|
"You are a prompt optimization assistant.\n"
|
||||||
logger.error(f"Failed to parse model reponse to json - Error: {str(e)}", exc_info=True)
|
"Compare the original prompt, the user's requirements, "
|
||||||
raise BusinessException("Failed to parse model response", BizCode.PARSER_NOT_SUPPORTED)
|
"and the optimized prompt.\n"
|
||||||
return model_resp
|
"Summarize the changes made during optimization.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
"1. Output must be a single short sentence.\n"
|
||||||
|
"2. Be concise and factual.\n"
|
||||||
|
"3. Do not explain the prompts themselves.\n"
|
||||||
|
"4. Do not include any extra text."
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"[Original Prompt]\n"
|
||||||
|
f"{current_prompt}\n\n"
|
||||||
|
"[User Requirements]\n"
|
||||||
|
f"{user_require}\n\n"
|
||||||
|
"[Optimized Prompt]\n"
|
||||||
|
f"{optim_prompt.content}"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
optim_desc = await llm.ainvoke(optim_desc)
|
||||||
|
|
||||||
|
return OptimizePromptResult(
|
||||||
|
prompt=optim_prompt.content,
|
||||||
|
desc=optim_desc.content
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parser_prompt_variables(prompt: str):
|
def parser_prompt_variables(prompt: str):
|
||||||
@@ -277,4 +276,3 @@ class PromptOptimizerService:
|
|||||||
content=content
|
content=content
|
||||||
)
|
)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Any, Annotated
|
from typing import Any, Annotated, AsyncGenerator
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
@@ -81,7 +81,7 @@ class WorkflowService:
|
|||||||
if not is_valid:
|
if not is_valid:
|
||||||
logger.warning(f"工作流配置验证失败: {errors}")
|
logger.warning(f"工作流配置验证失败: {errors}")
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
error_code=BizCode.INVALID_PARAMETER,
|
code=BizCode.INVALID_PARAMETER,
|
||||||
message=f"工作流配置无效: {'; '.join(errors)}"
|
message=f"工作流配置无效: {'; '.join(errors)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ class WorkflowService:
|
|||||||
config = self.get_workflow_config(app_id)
|
config = self.get_workflow_config(app_id)
|
||||||
if not config:
|
if not config:
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
code=BizCode.NOT_FOUND,
|
||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -166,7 +166,7 @@ class WorkflowService:
|
|||||||
if not is_valid:
|
if not is_valid:
|
||||||
logger.warning(f"工作流配置验证失败: {errors}")
|
logger.warning(f"工作流配置验证失败: {errors}")
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
error_code=BizCode.INVALID_PARAMETER,
|
code=BizCode.INVALID_PARAMETER,
|
||||||
message=f"工作流配置无效: {'; '.join(errors)}"
|
message=f"工作流配置无效: {'; '.join(errors)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -245,7 +245,7 @@ class WorkflowService:
|
|||||||
config = self.get_workflow_config(app_id)
|
config = self.get_workflow_config(app_id)
|
||||||
if not config:
|
if not config:
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
code=BizCode.NOT_FOUND,
|
||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -359,7 +359,7 @@ class WorkflowService:
|
|||||||
execution = self.get_execution(execution_id)
|
execution = self.get_execution(execution_id)
|
||||||
if not execution:
|
if not execution:
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
code=BizCode.NOT_FOUND,
|
||||||
message=f"执行记录不存在: execution_id={execution_id}"
|
message=f"执行记录不存在: execution_id={execution_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -640,7 +640,7 @@ class WorkflowService:
|
|||||||
triggered_by: uuid.UUID,
|
triggered_by: uuid.UUID,
|
||||||
conversation_id: uuid.UUID | None = None,
|
conversation_id: uuid.UUID | None = None,
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
):
|
) -> AsyncGenerator | dict:
|
||||||
"""运行工作流
|
"""运行工作流
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -660,7 +660,7 @@ class WorkflowService:
|
|||||||
config = self.get_workflow_config(app_id)
|
config = self.get_workflow_config(app_id)
|
||||||
if not config:
|
if not config:
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
code=BizCode.NOT_FOUND,
|
||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -687,7 +687,7 @@ class WorkflowService:
|
|||||||
app = self.db.query(App).filter(App.id == app_id).first()
|
app = self.db.query(App).filter(App.id == app_id).first()
|
||||||
if not app:
|
if not app:
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
code=BizCode.NOT_FOUND,
|
||||||
message=f"应用不存在: app_id={app_id}"
|
message=f"应用不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -750,7 +750,7 @@ class WorkflowService:
|
|||||||
error_message=str(e)
|
error_message=str(e)
|
||||||
)
|
)
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
error_code=BizCode.INTERNAL_ERROR,
|
code=BizCode.INTERNAL_ERROR,
|
||||||
message=f"工作流执行失败: {str(e)}"
|
message=f"工作流执行失败: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user