Merge #17 into develop from feature/20251219_myh

feat(workflow): add conditional branch (If-Else) node

* feature/20251219_myh: (10 commits)
  fix(workflow): fix run_workflow streaming issues
  fix(prompt-optimizer): switch to built-in system prompt
  feat(workflow): add conditional branch (If-Else) node
  perf(types): add Union type declaration for workflow nodes
  fix(expression-eval): fix variable extraction issue in Jinja2 templates
  docs(samples): add config example for If-Else node
  style(workflow): update condition edge comments for conditional nodes
  style(enums): correct enum class name spelling
  refactor(workflow): unify all enum classes in one file and restructure workflow...
  feat(workflow): add import for if-else node configuration

Signed-off-by: Eternity <1533512157@qq.com>
Commented-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/17
This commit is contained in:
朱文辉
2025-12-19 18:18:50 +08:00
16 changed files with 478 additions and 335 deletions

View File

@@ -117,7 +117,7 @@ async def get_prompt_opt(
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
message=data.message
user_require=data.message
)
service.create_message(
tenant_id=current_user.tenant_id,
@@ -136,35 +136,3 @@ async def get_prompt_opt(
return success(data=result_schema)
@router.put(
"/model",
summary="Create or update prompt model config",
response_model=ApiResponse
)
def set_system_prompt(
data: PromptOptModelSet = ...,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Create or update a system prompt model configuration for the tenant.
Args:
data (PromptOptModelSet): Model configuration data including model ID,
system prompt, and optional configuration ID
db (Session): Database session
current_user: Current user information
Returns:
UUID: The ID of the created or updated model configuration.
"""
if data.id is None:
data.id = uuid.uuid4()
model_config = PromptOptimizerService(db).create_update_model_config(
current_user.tenant_id,
data.id,
data.system_prompt
)
return success(data=model_config.id)

View File

@@ -473,7 +473,7 @@ async def run_workflow(
async def event_generator():
"""生成 SSE 事件"""
try:
async for event in service.run_workflow(
async for event in await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,

View File

@@ -4,16 +4,17 @@
基于 LangGraph 的工作流执行引擎。
"""
import logging
import datetime
import logging
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
@@ -25,11 +26,11 @@ class WorkflowExecutor:
"""
def __init__(
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""初始化执行器
@@ -90,8 +91,6 @@ class WorkflowExecutor:
"error_node": None
}
def build_graph(self) -> CompiledStateGraph:
"""构建 LangGraph
@@ -112,19 +111,38 @@ class WorkflowExecutor:
node_id = node.get("id")
# 记录 start 和 end 节点 ID
if node_type == "start":
if node_type == NodeType.START:
start_node_id = node_id
elif node_type == "end":
elif node_type == NodeType.END:
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
branch_number = len(expressions)
# Find all edges whose source is the current node
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(branch_number):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
@@ -165,14 +183,14 @@ class WorkflowExecutor:
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
@@ -196,8 +214,8 @@ class WorkflowExecutor:
return graph
async def execute(
self,
input_data: dict[str, Any]
self,
input_data: dict[str, Any]
) -> dict[str, Any]:
"""执行工作流(非流式)
@@ -271,8 +289,8 @@ class WorkflowExecutor:
}
async def execute_stream(
self,
input_data: dict[str, Any]
self,
input_data: dict[str, Any]
):
"""执行工作流(流式)
@@ -305,7 +323,7 @@ class WorkflowExecutor:
try:
async for chunk in graph.astream(
initial_state,
# subgraphs=True,
# subgraphs=True,
stream_mode="updates",
):
# print(chunk)
@@ -326,7 +344,6 @@ class WorkflowExecutor:
"token_usage": None
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
@@ -386,11 +403,11 @@ class WorkflowExecutor:
async def execute_workflow(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
) -> dict[str, Any]:
"""执行工作流(便捷函数)
@@ -414,11 +431,11 @@ async def execute_workflow(
async def execute_workflow_stream(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""执行工作流(流式,便捷函数)

View File

@@ -5,6 +5,7 @@
"""
import logging
import re
from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
@@ -59,9 +60,10 @@ class ExpressionEvaluator:
"""
# 移除 Jinja2 模板语法的花括号(如果存在)
expression = expression.strip()
if expression.startswith("{{") and expression.endswith("}}"):
expression = expression[2:-2].strip()
# "{{system.message}} == {{ user.message }}" -> "system.message == user.message"
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip()
# 构建命名空间上下文
context = {
"var": variables, # 用户变量

View File

@@ -4,13 +4,14 @@
提供各种类型的节点实现,用于工作流执行。
"""
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.node_factory import NodeFactory
from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
__all__ = [
"BaseNode",
@@ -18,7 +19,9 @@ __all__ = [
"LLMNode",
"AgentNode",
"TransformNode",
"IfElseNode",
"StartNode",
"EndNode",
"NodeFactory",
"WorkflowNode"
]

View File

@@ -13,6 +13,7 @@ from app.core.workflow.nodes.end.config import EndNodeConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
__all__ = [
# 基础类
@@ -26,4 +27,5 @@ __all__ = [
"MessageConfig",
"AgentNodeConfig",
"TransformNodeConfig",
"IfElseNodeConfig",
]

View File

@@ -1,5 +1,6 @@
from enum import StrEnum
class NodeType(StrEnum):
START = "start"
END = "end"
@@ -13,3 +14,23 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request"
TOOL = "tool"
AGENT = "agent"
class ComparisonOperator(StrEnum):
    # Comparison operators selectable in a single If-Else condition.
    # Each value is the string stored in the node configuration;
    # ConditionExpressionBuilder turns it into a boolean expression string
    # built from the condition's `left` and `right` operands.

    # Emptiness checks: rendered as `left == ''` / `left != ''`.
    EMPTY = "empty"
    NOT_EMPTY = "not_empty"

    # Membership checks: rendered as `right in left` / `right not in left`.
    CONTAINS = "contains"
    NOT_CONTAINS = "not_contains"

    # Prefix / suffix checks: rendered as `left.startswith(right)` /
    # `left.endswith(right)`. Note the stored values have no trailing "s".
    START_WITH = "startwith"
    END_WITH = "endwith"

    # Equality and ordering: rendered as `left <op> right` with
    # ==, !=, <, <=, > and >= respectively.
    EQ = "eq"
    NE = "ne"
    LT = "lt"
    LE = "le"
    GT = "gt"
    GE = "ge"
class LogicOperator(StrEnum):
    # Logical connective used to combine the conditions of one If-Else branch.
    # The value is interpolated verbatim between the individual condition
    # expressions (" and " / " or ") when the branch expression is built.
    AND = "and"
    OR = "or"

View File

@@ -0,0 +1,5 @@
"""Condition Node"""
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
from app.core.workflow.nodes.if_else.node import IfElseNode
__all__ = ["IfElseNode", "IfElseNodeConfig"]

View File

@@ -0,0 +1,97 @@
"""Condition Configuration"""
from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
class ConditionDetail(BaseModel):
    # One comparison inside an If-Else branch: `left <operator> right`.
    # Documented with comments rather than a class docstring on purpose:
    # a docstring on a pydantic model is emitted as the schema description.

    # Operator that decides how `left` and `right` are combined into an
    # expression string (see ConditionExpressionBuilder).
    comparison_operator: ComparisonOperator = Field(
        ...,
        description="Comparison operator used to evaluate the condition"
    )
    # Left operand, inserted verbatim into the generated expression
    # (the bundled example uses references such as "node.userinput.message").
    left: str = Field(
        ...,
        description="Value to compare against"
    )
    # Right operand, also inserted verbatim. String literals therefore carry
    # their own quotes, as in the bundled example value "'123'".
    right: str = Field(
        ...,
        description="Value to compare with"
    )
class ConditionBranchConfig(BaseModel):
    """Configuration for a conditional branch"""

    # Connective placed between the branch's condition expressions.
    # The default is the enum member itself rather than its `.value`:
    # pydantic does not validate defaults, so a plain-string default would
    # leave a `str` in a field annotated as `LogicOperator`.
    logical_operator: LogicOperator = Field(
        default=LogicOperator.AND,
        description="Logical operator used to combine multiple condition expressions"
    )
    # Individual comparisons that make up this branch, in evaluation order.
    conditions: list[ConditionDetail] = Field(
        ...,
        description="List of condition expressions within this branch"
    )
class IfElseNodeConfig(BaseNodeConfig):
    # Configuration of an If-Else node: an ordered list of branches ("cases").
    # Documented with comments rather than a class docstring on purpose:
    # a docstring on a pydantic model is emitted as the schema description.

    # Branches in priority order. The ELSE branch is not listed here; the
    # node appends an always-true fallback when it builds its expressions.
    cases: list[ConditionBranchConfig] = Field(
        ...,
        description="List of branch conditions or expressions"
    )

    @field_validator("cases")
    @classmethod
    def validate_case_number(cls, v, info):
        """Reject a configuration that declares no branch at all.

        Args:
            v: The validated list of branch configurations.
            info: Validation context supplied by pydantic (unused).

        Returns:
            The unchanged list when it holds at least one branch.

        Raises:
            ValueError: If the list is empty.
        """
        if not v:
            raise ValueError("At least one cases are required")
        return v

    class Config:
        # The example mirrors the schema: every case carries a flat list of
        # condition objects under "conditions".
        json_schema_extra = {
            "examples": [
                {
                    "cases": [
                        # CASE1 / IF Branch
                        {
                            "logical_operator": "and",
                            "conditions": [
                                {
                                    "left": "node.userinput.message",
                                    "comparison_operator": "eq",
                                    "right": "'123'"
                                },
                                {
                                    "left": "node.userinput.test",
                                    "comparison_operator": "eq",
                                    "right": "True"
                                }
                            ]
                        },
                        # CASE2 / ELIF Branch
                        {
                            "logical_operator": "or",
                            "conditions": [
                                {
                                    "left": "node.userinput.test",
                                    "comparison_operator": "eq",
                                    "right": "False"
                                },
                                {
                                    "left": "node.userinput.message",
                                    "comparison_operator": "contains",
                                    "right": "'123'"
                                }
                            ]
                        }
                        # CASE3 / ELSE Branch (implicit, added by the node)
                    ]
                }
            ]
        }

View File

@@ -0,0 +1,167 @@
import logging
from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail
logger = logging.getLogger(__name__)
class ConditionExpressionBuilder:
    """
    Render a comparison as a Python boolean expression string.

    Nothing is evaluated here: the builder only substitutes the two operand
    strings into the template that belongs to the requested operator and
    returns the resulting text for later evaluation.
    """

    # Operator -> expression template, in the order the operators are tried.
    # `{left}` / `{right}` are replaced verbatim by the operand strings.
    _TEMPLATES = (
        (ComparisonOperator.EMPTY, "{left} == ''"),
        (ComparisonOperator.NOT_EMPTY, "{left} != ''"),
        (ComparisonOperator.CONTAINS, "{right} in {left}"),
        (ComparisonOperator.NOT_CONTAINS, "{right} not in {left}"),
        (ComparisonOperator.START_WITH, "{left}.startswith({right})"),
        (ComparisonOperator.END_WITH, "{left}.endswith({right})"),
        (ComparisonOperator.EQ, "{left} == {right}"),
        (ComparisonOperator.NE, "{left} != {right}"),
        (ComparisonOperator.LT, "{left} < {right}"),
        (ComparisonOperator.LE, "{left} <= {right}"),
        (ComparisonOperator.GT, "{left} > {right}"),
        (ComparisonOperator.GE, "{left} >= {right}"),
    )

    def __init__(self, left: str, operator: ComparisonOperator, right: str):
        self.left = left
        self.operator = operator
        self.right = right

    def build(self):
        """Return the expression string for the configured operator.

        Raises:
            ValueError: If the operator matches none of the known operators.
        """
        # Equality scan (not a hash lookup) so that any operator value which
        # compares equal to an enum member is accepted.
        for known_operator, template in self._TEMPLATES:
            if self.operator == known_operator:
                return template.format(left=self.left, right=self.right)
        raise ValueError(f"Invalid condition: {self.operator}")
class IfElseNode(BaseNode):
    """Workflow node that selects one branch out of an ordered list of cases.

    Each configured case is turned into a boolean expression string; the
    node evaluates them in order and reports the first match as 'CASE<n>'.
    """

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        super().__init__(node_config, workflow_config)
        # NOTE(review): `self.config` is presumably populated by
        # BaseNode.__init__ from `node_config` - confirm against BaseNode.
        self.typed_config = IfElseNodeConfig(**self.config)

    @staticmethod
    def _build_condition_expression(
        condition: ConditionDetail,
    ) -> str:
        """
        Build a single boolean condition expression string.

        This method does NOT evaluate the condition. It only generates an
        expression string (e.g. "x > 10", "'a' in name") from the operands
        and operator of the given condition.

        Args:
            condition (ConditionDetail): Definition of a single comparison condition.

        Returns:
            str: A Python boolean expression string.
        """
        return ConditionExpressionBuilder(
            left=condition.left,
            operator=condition.comparison_operator,
            right=condition.right
        ).build()

    def build_conditional_edge_expressions(self) -> list[str]:
        """
        Build one boolean expression string per branch of the If-Else node.

        Nothing is evaluated here. The conditions of each case are rendered
        individually and joined with the case's logical operator. A final
        'True' entry is appended so that a default (ELSE) branch always
        exists when no earlier expression matches.

        Returns:
            list[str]: Expression strings ordered by branch priority, with
            the always-true fallback last.

        Raises:
            ValueError: If a case declares no conditions, since no
                expression can be built for it.
        """
        conditions = []
        for case_number, case_branch in enumerate(self.typed_config.cases, start=1):
            branch_conditions = [
                self._build_condition_expression(condition)
                for condition in case_branch.conditions
            ]
            if not branch_conditions:
                raise ValueError(
                    f"CASE{case_number} of the If-Else node has no conditions"
                )
            # Joining a single-element list returns that element unchanged,
            # so one and many conditions share the same code path.
            conditions.append(
                f' {case_branch.logical_operator} '.join(branch_conditions)
            )

        # Default fallback branch
        conditions.append("True")
        return conditions

    async def execute(self, state: WorkflowState) -> Any:
        """
        Evaluate the branch expressions in order and pick the first match.

        Args:
            state (WorkflowState): Current workflow state handed to the
                condition evaluator.

        Returns:
            str: 'CASE<n>' where n is the 1-based position of the first
            expression that evaluates truthy. If none does, the label of
            the last (fallback) branch is returned.
        """
        expressions = self.build_conditional_edge_expressions()
        for case_number, expression in enumerate(expressions, start=1):
            logger.info(expression)
            # NOTE(review): `_evaluate_condition` is not defined in this
            # class; presumably inherited from BaseNode - confirm.
            if self._evaluate_condition(expression, state):
                return f'CASE{case_number}'
        return f'CASE{len(expressions)}'

View File

@@ -5,18 +5,29 @@
"""
import logging
from typing import Any
from typing import Any, Union
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
logger = logging.getLogger(__name__)
WorkflowNode = Union[
BaseNode,
StartNode,
EndNode,
LLMNode,
IfElseNode,
AgentNode,
TransformNode,
]
class NodeFactory:
"""节点工厂
@@ -25,16 +36,17 @@ class NodeFactory:
"""
# 节点类型注册表
_node_types: dict[str, type[BaseNode]] = {
_node_types: dict[str, type[WorkflowNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode
}
@classmethod
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
def register_node_type(cls, node_type: str, node_class: type[WorkflowNode]):
"""注册新的节点类型
Args:
@@ -52,10 +64,10 @@ class NodeFactory:
@classmethod
def create_node(
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
) -> BaseNode | None:
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
) -> WorkflowNode | None:
"""创建节点实例
Args:

View File

@@ -20,7 +20,7 @@ from .data_config_model import DataConfig
from .multi_agent_model import MultiAgentConfig, AgentInvocation
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from .retrieval_info import RetrievalInfo
from .prompt_optimizer_model import PromptOptimizerModelConfig, PromptOptimizerSession, PromptOptimizerSessionHistory
from .prompt_optimizer_model import PromptOptimizerSession, PromptOptimizerSessionHistory
__all__ = [
"Tenants",
@@ -56,7 +56,6 @@ __all__ = [
"WorkflowExecution",
"WorkflowNodeExecution",
"RetrievalInfo",
"PromptOptimizerModelConfig",
"PromptOptimizerSession",
"PromptOptimizerSessionHistory"
]

View File

@@ -27,49 +27,6 @@ class RoleType(StrEnum):
ASSISTANT = "assistant"
class PromptOptimizerModelConfig(Base):
"""
Prompt Optimization Model Configuration.
This table stores system-level prompt configurations for each tenant.
The configuration defines the base system prompt used during prompt
optimization sessions and serves as a foundational instruction set
for the optimization process.
Each tenant may have one or more model configurations depending on
business requirements.
Table Name:
prompt_model_config
Columns:
id (UUID):
Primary key. Unique identifier for the prompt model configuration.
tenant_id (UUID):
Foreign key referencing `tenants.id`.
Identifies the tenant that owns this configuration.
system_prompt (Text):
The system-level prompt used to guide prompt optimization logic.
created_at (DateTime):
Timestamp indicating when the configuration was created.
updated_at (DateTime):
Timestamp indicating the last update time of the configuration.
Usage:
- Loaded when initializing a prompt optimization session
- Acts as the root system instruction for all subsequent prompts
"""
__tablename__ = "prompt_model_config"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
# model_id = Column(UUID(as_uuid=True), nullable=False, comment="Model ID")
system_prompt = Column(Text, nullable=False, comment="System Prompt")
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="Update Time")
class PromptOptimizerSession(Base):
"""
Prompt Optimization Session Registry.

View File

@@ -1,120 +1,15 @@
import uuid
from typing import Optional
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.prompt_optimizer_model import (
PromptOptimizerModelConfig,
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
)
db_logger = get_db_logger()
class PromptOptimizerModelConfigRepository:
"""Repository for managing prompt optimizer model configurations."""
def __init__(self, db: Session):
self.db = db
def get_by_tenant_id(self, tenant_id: uuid.UUID) -> Optional[PromptOptimizerModelConfig]:
"""
Retrieve the prompt optimizer model configuration for a specific tenant.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
Returns:
Optional[PromptOptimizerModelConfig]: The model configuration if found, else None.
"""
db_logger.debug(f"Get prompt optimization model configuration: tenant_id={tenant_id}")
try:
config = self.db.query(PromptOptimizerModelConfig).filter(
PromptOptimizerModelConfig.tenant_id == tenant_id,
# PromptOptimizerModelConfig.model_id == model_id
).first()
if config:
db_logger.debug(f"Prompt optimization model configuration found: (ID: {config.id})")
else:
db_logger.debug(f"Prompt optimization model configuration not found: tenant_id={tenant_id}")
return config
except Exception as e:
db_logger.error(
f"Error retrieving prompt optimization model configuration: tenant_id={tenant_id} - {str(e)}")
raise
def get_by_config_id(self, tenant_id: uuid.UUID, config_id: uuid.UUID) -> Optional[PromptOptimizerModelConfig]:
"""
Retrieve a specific prompt optimizer model configuration by config ID and tenant ID.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
config_id (uuid.UUID): The unique identifier of the model configuration.
Returns:
Optional[PromptOptimizerModelConfig]: The model configuration if found, else None.
"""
db_logger.debug(f"Get prompt optimization model configuration: config_id={config_id}, tenant_id={tenant_id}")
try:
model = self.db.query(PromptOptimizerModelConfig).filter(
PromptOptimizerModelConfig.tenant_id == tenant_id,
PromptOptimizerModelConfig.id == config_id
).first()
if model:
db_logger.debug(f"Prompt optimization model configuration found: (ID: {model.id})")
else:
db_logger.debug(f"Prompt optimization model configuration not found: config_id={config_id}")
return model
except Exception as e:
db_logger.error(
f"Error retrieving prompt optimization model configuration: model_id={config_id} - {str(e)}")
raise
def create_or_update(
self,
config_id: uuid.UUID,
tenant_id: uuid.UUID,
system_prompt: str,
) -> Optional[PromptOptimizerModelConfig]:
"""
Create a new or update an existing prompt optimizer model configuration.
If a configuration with the given config_id exists, it updates its system_prompt.
Otherwise, it creates a new configuration record.
Args:
config_id (uuid.UUID): The unique identifier for the configuration.
tenant_id (uuid.UUID): The tenant's unique identifier.
system_prompt (str): The system prompt content for prompt optimization.
Returns:
Optional[PromptOptimizerModelConfig]: The created or updated model configuration.
"""
db_logger.debug(f"Create/Update prompt optimization model configuration: tenant_id={tenant_id}")
existing_config = self.get_by_config_id(tenant_id, config_id)
if existing_config:
existing_config.system_prompt = system_prompt
self.db.commit()
self.db.refresh(existing_config)
db_logger.debug(f"Prompt optimization model configuration update: ID:{config_id}")
return existing_config
else:
config = PromptOptimizerModelConfig(
id=config_id,
# model_id=model_id,
tenant_id=tenant_id,
system_prompt=system_prompt
)
self.db.add(config)
self.db.commit()
self.db.refresh(config)
db_logger.debug(f"Prompt optimization model configuration created: ID:{config.id}")
return config
class PromptOptimizerSessionRepository:
"""Repository for managing prompt optimization sessions and session history."""

View File

@@ -1,4 +1,3 @@
import json
import re
import uuid
@@ -12,13 +11,11 @@ from app.core.models import RedBearModelConfig
from app.core.models.llm import RedBearLLM
from app.models import ModelConfig, ModelApiKey, ModelType, PromptOptimizerSessionHistory
from app.models.prompt_optimizer_model import (
PromptOptimizerModelConfig,
PromptOptimizerSession,
RoleType
)
from app.repositories.model_repository import ModelConfigRepository
from app.repositories.prompt_optimizer_repository import (
PromptOptimizerModelConfigRepository,
PromptOptimizerSessionRepository
)
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
@@ -34,32 +31,24 @@ class PromptOptimizerService:
self,
tenant_id: uuid.UUID,
model_id: uuid.UUID
) -> tuple[PromptOptimizerModelConfig, ModelConfig]:
) -> ModelConfig:
"""
Retrieve the prompt optimizer model configuration and model configuration.
Retrieve the model configuration for a specific tenant.
This method retrieves the prompt optimizer model configuration associated
with the specified model ID and tenant. It also fetches the corresponding
model configuration.
This method fetches the model configuration associated with the given
tenant_id and model_id. If no configuration is found, a BusinessException
is raised.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
model_id (uuid.UUID): The unique identifier of the prompt optimization model.
model_id (uuid.UUID): The unique identifier of the model.
Returns:
tuple[PromptOptimzerModelConfig, ModelConfig]:
A tuple containing the prompt optimizer model configuration
and the corresponding model configuration.
ModelConfig: The corresponding model configuration object.
Raises:
BusinessException: If the prompt optimizer model configuration does not exist.
BusinessException: If the model configuration does not exist.
"""
prompt_config = PromptOptimizerModelConfigRepository(self.db).get_by_tenant_id(
tenant_id
)
if not prompt_config:
raise BusinessException("提示词模型配置不存在", BizCode.NOT_FOUND)
model = ModelConfigRepository.get_by_id(
self.db, model_id, tenant_id=tenant_id
@@ -67,35 +56,7 @@ class PromptOptimizerService:
if not model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return prompt_config, model
def create_update_model_config(
self,
tenant_id: uuid.UUID,
config_id: uuid.UUID,
system_prompt: str,
) -> PromptOptimizerModelConfig:
"""
Create or update a prompt optimizer model configuration.
This method creates a new prompt optimizer model configuration or updates
an existing one identified by the given configuration ID. The configuration
defines the system prompt used for prompt optimization.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
config_id (uuid.UUID): The unique identifier of the configuration to create or update.
system_prompt (str): The system prompt content used for prompt optimization.
Returns:
PromptOptimzerModelConfig: The created or updated prompt optimizer model configuration.
"""
prompt_config = PromptOptimizerModelConfigRepository(self.db).create_or_update(
config_id=config_id,
tenant_id=tenant_id,
system_prompt=system_prompt,
)
return prompt_config
return model
def create_session(
self,
@@ -159,37 +120,46 @@ class PromptOptimizerService:
session_id: uuid.UUID,
user_id: uuid.UUID,
current_prompt: str,
message: str
user_require: str
) -> OptimizePromptResult:
"""
Optimize a prompt using a prompt optimizer LLM.
Optimize a user-provided prompt using a configured prompt optimizer LLM.
This method uses a configured prompt optimizer model to refine an existing
prompt based on the user's requirements. The optimized prompt is generated
according to predefined system rules, including Jinja2 variable syntax and
a strict JSON output format.
This method refines the original prompt according to the user's requirements,
generating an optimized version that is directly usable by AI tools. The
optimization process follows strict rules, including:
- Wrapping user-inserted variables in double curly braces {{}}.
- Adhering to Jinja2 variable syntax if applicable.
- Ensuring a clear logic flow, explicit instructions, and strong executability.
- Producing output in a strict JSON format.
Steps performed:
1. Retrieve the model configuration for the given tenant and model.
2. Fetch the session message history for context.
3. Instantiate the LLM with the appropriate API key and model configuration.
4. Build system messages outlining optimization rules.
5. Format the user's original prompt and requirements as a user message.
6. Send messages to the LLM to generate the optimized prompt.
7. Generate a concise description summarizing the changes made during optimization.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
model_id (uuid.UUID): The unique identifier of the prompt optimizer model.
session_id (uuid.UUID): The unique identifier of the prompt optimization session.
user_id (uuid.UUID): The unique identifier of the user associated with the session.
current_prompt (str): The original prompt to be optimized.
message (str): The user's requirements or modification instructions.
tenant_id (uuid.UUID): Tenant identifier.
model_id (uuid.UUID): Prompt optimizer model identifier.
session_id (uuid.UUID): Prompt optimization session identifier.
user_id (uuid.UUID): Identifier of the user associated with the session.
current_prompt (str): Original prompt to optimize.
user_require (str): User's requirements or instructions for optimization.
Returns:
dict: A dictionary containing the optimized prompt and the description
of changes, in the following format:
{
"prompt": "<optimized_prompt>",
"desc": "<change_description>"
}
OptimizePromptResult: An object containing:
- prompt: The optimized prompt string.
- desc: A short description summarizing the changes.
Raises:
BusinessException: If the model response cannot be parsed as valid JSON
BusinessException: If the LLM response cannot be parsed as valid JSON
or does not conform to the expected output format.
"""
prompt_config, model_config = self.get_model_config(tenant_id, model_id)
model_config = self.get_model_config(tenant_id, model_id)
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
# Create LLM instance
@@ -204,36 +174,65 @@ class PromptOptimizerService:
# build message
messages = [
# init system_prompt
(RoleType.SYSTEM.value, prompt_config.system_prompt),
(
RoleType.SYSTEM.value,
"Your task is to optimize the original prompt provided by the user so that it can be directly used by AI tools,"
"and the variables that the user needs to insert must be wrapped in {{}}. "
"The optimized prompt should align with the optimization direction specified by the user (if any) and ensure clear logic, explicit instructions, and strong executability. "
"Please follow these rules when optimizing: "
'1. Ensure variables are wrapped in {{}}, e.g., optimize "Please enter your question" to "Please enter your {{question}}"'
"2. Instructions must be specific and operable, avoiding vague expressions"
"3. If the original prompt lacks key elements (such as output format requirements), supplement them completely "
"4. Keep the language concise and avoid redundancy "
"5. If the user does not specify an optimization direction, the default optimization is to make the prompt structurally clear and with explicit instructions"
"Please directly output the optimized prompt without additional explanations. The optimized prompt should be directly usable with correct variable positions."
),
# base model limit
(RoleType.SYSTEM.value,
"Optimization Rules:\n"
"1. Fully adjust the prompt content according to the user's requirements.\n"
"2. When the user requests the insertion of variables, you must use Jinja2 syntax {{variable_name}} "
"(the variable name should be determined based on the user's requirement).\n"
"When variables are required, use double curly braces {{variable_name}} as placeholders."
"Variable names must be derived from the user's requirements.\n"
"3. Keep the prompt logic clear and instructions explicit.\n"
"4. Ensure that the modified prompt can be directly used.\n\n"
"Output Requirements:\n"
"Provide the result in JSON format, containing exactly two fields:\n"
" - prompt: The modified prompt (string).\n"
" - desc: A response addressing the user's optimization request (string).")
"4. Ensure that the modified prompt can be directly used.\n\n")
]
messages.extend(session_history[:-1]) # last message is current message
user_message_template = ChatPromptTemplate.from_messages([
(RoleType.USER.value, "[current_prompt]\n{current_prompt}\n[user_require]\n{message}")
(RoleType.USER.value, "[original_prompt]\n{current_prompt}\n[user_require]\n{user_require}")
])
formatted_user_message = user_message_template.format(current_prompt=current_prompt, message=message)
formatted_user_message = user_message_template.format(current_prompt=current_prompt, user_require=user_require)
messages.extend([(RoleType.USER.value, formatted_user_message)])
logger.info(f"Prompt optimization message: {messages}")
result = await llm.ainvoke(messages)
try:
data_dict = json.loads(result.content)
model_resp = OptimizePromptResult.model_validate(data_dict)
except Exception as e:
logger.error(f"Failed to parse model reponse to json - Error: {str(e)}", exc_info=True)
raise BusinessException("Failed to parse model response", BizCode.PARSER_NOT_SUPPORTED)
return model_resp
optim_prompt = await llm.ainvoke(messages)
optim_desc = [
(
RoleType.SYSTEM.value,
"You are a prompt optimization assistant.\n"
"Compare the original prompt, the user's requirements, "
"and the optimized prompt.\n"
"Summarize the changes made during optimization.\n\n"
"Rules:\n"
"1. Output must be a single short sentence.\n"
"2. Be concise and factual.\n"
"3. Do not explain the prompts themselves.\n"
"4. Do not include any extra text."
),
(
"[Original Prompt]\n"
f"{current_prompt}\n\n"
"[User Requirements]\n"
f"{user_require}\n\n"
"[Optimized Prompt]\n"
f"{optim_prompt.content}"
)
]
optim_desc = await llm.ainvoke(optim_desc)
return OptimizePromptResult(
prompt=optim_prompt.content,
desc=optim_desc.content
)
@staticmethod
def parser_prompt_variables(prompt: str):
@@ -277,4 +276,3 @@ class PromptOptimizerService:
content=content
)
return message

View File

@@ -5,7 +5,7 @@ import json
import logging
import uuid
import datetime
from typing import Any, Annotated
from typing import Any, Annotated, AsyncGenerator
from sqlalchemy.orm import Session
from fastapi import Depends
@@ -81,7 +81,7 @@ class WorkflowService:
if not is_valid:
logger.warning(f"工作流配置验证失败: {errors}")
raise BusinessException(
error_code=BizCode.INVALID_PARAMETER,
code=BizCode.INVALID_PARAMETER,
message=f"工作流配置无效: {'; '.join(errors)}"
)
@@ -140,7 +140,7 @@ class WorkflowService:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
code=BizCode.NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
@@ -166,7 +166,7 @@ class WorkflowService:
if not is_valid:
logger.warning(f"工作流配置验证失败: {errors}")
raise BusinessException(
error_code=BizCode.INVALID_PARAMETER,
code=BizCode.INVALID_PARAMETER,
message=f"工作流配置无效: {'; '.join(errors)}"
)
@@ -245,7 +245,7 @@ class WorkflowService:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
code=BizCode.NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
@@ -359,7 +359,7 @@ class WorkflowService:
execution = self.get_execution(execution_id)
if not execution:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
code=BizCode.NOT_FOUND,
message=f"执行记录不存在: execution_id={execution_id}"
)
@@ -640,7 +640,7 @@ class WorkflowService:
triggered_by: uuid.UUID,
conversation_id: uuid.UUID | None = None,
stream: bool = False
):
) -> AsyncGenerator | dict:
"""运行工作流
Args:
@@ -660,7 +660,7 @@ class WorkflowService:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
code=BizCode.NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
@@ -687,7 +687,7 @@ class WorkflowService:
app = self.db.query(App).filter(App.id == app_id).first()
if not app:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
code=BizCode.NOT_FOUND,
message=f"应用不存在: app_id={app_id}"
)
@@ -750,7 +750,7 @@ class WorkflowService:
error_message=str(e)
)
raise BusinessException(
error_code=BizCode.INTERNAL_ERROR,
code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)