diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index 33fe040c..174fa877 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -16,6 +16,7 @@ from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode +from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode __all__ = [ "BaseNode", @@ -32,4 +33,5 @@ __all__ = [ "AssignerNode", "HttpRequestNode", "JinjaRenderNode", + "ParameterExtractorNode" ] diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index 12bb18cb..a8363421 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -19,6 +19,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig +from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig __all__ = [ # 基础类 @@ -38,4 +39,5 @@ __all__ = [ "HttpRequestNodeConfig", "JinjaRenderNodeConfig", "VariableAggregatorNodeConfig", + "ParameterExtractorNodeConfig", ] diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 3ca59695..b4cc0634 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -26,6 +26,7 @@ class NodeType(StrEnum): ASSIGNER = "assigner" JINJARENDER = "jinja-render" VAR_AGGREGATOR = "var-aggregator" + PARAMETER_EXTRACTOR = "parameter-extractor" class ComparisonOperator(StrEnum): diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 9a5b5093..98c1468f 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -17,6 +17,7 @@ from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.jinja_render import JinjaRenderNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.llm import LLMNode +from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode @@ -35,7 +36,8 @@ WorkflowNode = Union[ HttpRequestNode, KnowledgeRetrievalNode, JinjaRenderNode, - VariableAggregatorNode + VariableAggregatorNode, + ParameterExtractorNode ] @@ -57,7 +59,8 @@ class NodeFactory: NodeType.ASSIGNER: AssignerNode, NodeType.HTTP_REQUEST: HttpRequestNode, NodeType.JINJARENDER: JinjaRenderNode, - NodeType.VAR_AGGREGATOR: VariableAggregatorNode + NodeType.VAR_AGGREGATOR: VariableAggregatorNode, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, } @classmethod diff --git a/api/app/core/workflow/nodes/parameter_extractor/__init__.py b/api/app/core/workflow/nodes/parameter_extractor/__init__.py new file mode 100644 index 00000000..342ea27c --- /dev/null +++ b/api/app/core/workflow/nodes/parameter_extractor/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig +from app.core.workflow.nodes.parameter_extractor.node import ParameterExtractorNode + +__all__ = ["ParameterExtractorNode", "ParameterExtractorNodeConfig"] diff --git a/api/app/core/workflow/nodes/parameter_extractor/config.py b/api/app/core/workflow/nodes/parameter_extractor/config.py new file mode 100644 index 00000000..30c0e1ef --- /dev/null +++ b/api/app/core/workflow/nodes/parameter_extractor/config.py @@ -0,0 +1,54 @@ +import uuid + +from pydantic import Field, BaseModel +from enum import StrEnum + +from app.core.workflow.nodes.base_config import BaseNodeConfig + + +class ParamVariableType(StrEnum): + """ + Enum for variable types that can be extracted as parameters. + Each member represents a type that can be used in parameter extraction. + """ + STRING = "string" + NUMBER = "number" + BOOLEAN = "boolean" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_BOOLEAN = "array[boolean]" + ARRAY_OBJECT = "array[object]" + + +class ParamsConfig(BaseModel): + name: str = Field( + ..., + description="Name of the parameter" + ) + + type: ParamVariableType = Field( + ..., + description="Type of the parameter" + ) + + desc: str = Field( + ..., + description="Description of the parameter" + ) + + +class ParameterExtractorNodeConfig(BaseNodeConfig): + model_id: uuid.UUID = Field( + ..., + description="Unique identifier for the model" + ) + + text: str = Field( + ..., + description="The string to be extracted as a parameter" + ) + + params: list[ParamsConfig] = Field( + ..., + description="List of parameters" + ) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py new file mode 100644 index 00000000..0eb3bfd4 --- /dev/null +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -0,0 +1,165 @@ +import os + +import json_repair +from typing import Any + +from jinja2 import Template + +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.models import RedBearLLM, RedBearModelConfig +from app.core.workflow.nodes import WorkflowState +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig +from app.db import get_db_read +from app.models import ModelType +from app.services.model_service import ModelConfigService + + +class ParameterExtractorNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = ParameterExtractorNodeConfig(**self.config) + + @staticmethod + def _get_prompt(): + """ + Load system and user prompt templates from local prompt files. + + Notes: + - Templates are expected to be Jinja2 files. + - Reading from disk each time ensures the latest template is used (could be cached if performance-critical). + - Both templates must exist, otherwise an exception will be raised. + + Returns: + Tuple[str, str]: system_prompt, user_prompt + """ + current_dir = os.path.dirname(os.path.abspath(__file__)) + with open( + os.path.join(current_dir, "prompt", "system_prompt.jinja2"), + encoding='utf-8' + ) as f: + system_prompt = f.read() + with open(os.path.join( + current_dir, "prompt", "user_prompt.jinja2"), + encoding='utf-8' + ) as f: + user_prompt = f.read() + return system_prompt, user_prompt + + def _get_llm_instance(self) -> RedBearLLM: + """ + Retrieve a configured LLM instance based on the model ID from database. + + Responsibilities: + - Validate that the model exists and has at least one API key configured. + - Construct RedBearLLM instance with proper credentials and model type. + - Raise clear BusinessException if configuration is invalid. + + Returns: + RedBearLLM: Configured LLM instance ready to be invoked. + + Raises: + BusinessException: If the model is missing or lacks valid API key. + """ + model_id = self.typed_config.model_id + + with get_db_read() as db: + config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) + + if not config: + raise BusinessException("Configured model does not exist", BizCode.NOT_FOUND) + + if not config.api_keys or len(config.api_keys) == 0: + raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER) + + api_config = config.api_keys[0] + model_name = api_config.model_name + provider = api_config.provider + api_key = api_config.api_key + api_base = api_config.api_base + model_type = config.type + + llm = RedBearLLM( + RedBearModelConfig( + model_name=model_name, + provider=provider, + api_key=api_key, + base_url=api_base, + ), + type=ModelType(model_type) + ) + return llm + + def _get_field_desc(self) -> dict[str, str]: + """ + Build a dictionary mapping each parameter name to its description. + Useful for dynamically generating prompts for LLM. + + Returns: + dict[str, str]: Mapping of parameter names to descriptions. + """ + field_desc = {} + for param in self.typed_config.params: + field_desc[param.name] = param.desc + return field_desc + + def _get_field_type(self) -> dict[str, str]: + """ + Build a dictionary mapping each parameter name to its description. + Useful for dynamically generating prompts for LLM. + + Returns: + dict[str, str]: Mapping of parameter names to descriptions. + """ + field_type = {} + for param in self.typed_config.params: + field_type[param.name] = param.type + return field_type + + async def execute(self, state: WorkflowState) -> Any: + """ + Main execution function for this node. + + Workflow: + 1. Retrieve LLM instance with valid credentials. + 2. Render user prompt template with field descriptions, types, and input text. + 3. Send system and user prompts to LLM asynchronously. + 4. Repair LLM JSON output safely. + 5. Return output dictionary. + + Notes: + - JSON repair is used to handle minor formatting errors in LLM output. + - Exceptions are raised explicitly if parsing fails, to prevent silent workflow failures. + - Rendering uses self._render_template for dynamic substitution from workflow state. + + Args: + state (WorkflowState): Current state of the workflow, used for template rendering. + + Returns: + dict[str, Any]: Dictionary containing extracted parameters under the "output" key. + + Raises: + BusinessException: If LLM output cannot be parsed as valid JSON. + """ + llm = self._get_llm_instance() + system_prompt, user_prompt = self._get_prompt() + + user_prompt_teplate = Template(user_prompt) + rendered_user_prompt = user_prompt_teplate.render( + field_descriptions=str(self._get_field_desc()), + field_type=str(self._get_field_type()), + text_input=self._render_template(self.typed_config.text, state) + ) + + messages = [ + ("system", system_prompt), + ("user", rendered_user_prompt), + ] + + model_resp = await llm.ainvoke(messages) + result = json_repair.repair_json(model_resp.content) + + return { + "output": result, + } diff --git a/api/app/core/workflow/nodes/parameter_extractor/prompt/system_prompt.jinja2 b/api/app/core/workflow/nodes/parameter_extractor/prompt/system_prompt.jinja2 new file mode 100644 index 00000000..28d0e1f2 --- /dev/null +++ b/api/app/core/workflow/nodes/parameter_extractor/prompt/system_prompt.jinja2 @@ -0,0 +1,10 @@ +You are an information extraction engine. + +Your task is to extract structured data from text and output a valid JSON object (json). + +Rules: +- Output MUST be a single valid JSON object. +- No explanations, markdown, code blocks, or extra text. +- Never invent information. +- If a field is not found, use null. +- Follow the output structure exactly. diff --git a/api/app/core/workflow/nodes/parameter_extractor/prompt/user_prompt.jinja2 b/api/app/core/workflow/nodes/parameter_extractor/prompt/user_prompt.jinja2 new file mode 100644 index 00000000..82f3940b --- /dev/null +++ b/api/app/core/workflow/nodes/parameter_extractor/prompt/user_prompt.jinja2 @@ -0,0 +1,12 @@ +Extract structured information from the following text. + +Field Descriptions: +{{field_descriptions}} + +Output Structure: +{{field_type}} + +Input Text: +{{text_input}} + +Output: \ No newline at end of file diff --git a/api/app/db.py b/api/app/db.py index 2513dc78..cdaa6dbd 100644 --- a/api/app/db.py +++ b/api/app/db.py @@ -23,6 +23,7 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() + # Dependency to get a DB session def get_db(): db = SessionLocal() @@ -35,6 +36,7 @@ def get_db(): finally: db.close() + @contextmanager def get_db_context() -> Generator[Session, None, None]: """ @@ -54,12 +56,16 @@ def get_db_context() -> Generator[Session, None, None]: db.rollback() db.close() + @contextmanager def get_db_read() -> Generator[Session, None, None]: """只读场景专用,出上下文自动 rollback,绝不留下 idle in transaction""" with get_db_context() as db: - yield db - db.rollback() # 只读任务无需 commit + try: + yield db + finally: + db.rollback() # 只读任务无需 commit + def get_pool_status(): """获取连接池状态(用于监控)""" @@ -70,5 +76,6 @@ def get_pool_status(): "checked_out": pool.checkedout(), "overflow": pool.overflow(), "total": pool.size() + pool.overflow(), - "usage_percent": round(pool.checkedout() / (pool.size() + pool.overflow()) * 100, 2) if (pool.size() + pool.overflow()) > 0 else 0 - } \ No newline at end of file + "usage_percent": round(pool.checkedout() / (pool.size() + pool.overflow()) * 100, 2) if ( + pool.size() + pool.overflow()) > 0 else 0 + } diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 6af794b1..5f325a1b 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -1,8 +1,10 @@ import re import uuid +import json_repair from langchain_core.prompts import ChatPromptTemplate from sqlalchemy.orm import Session +from jinja2 import Template from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -170,68 +172,45 @@ class PromptOptimizerService: api_key=api_config.api_key, base_url=api_config.api_base ), type=ModelType(model_config.type)) + try: + with open('app/templates/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f: + opt_system_prompt = f.read() + rendered_system_message = Template(opt_system_prompt).render() + + with open('app/templates/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f: + opt_user_prompt = f.read() + except FileNotFoundError: + raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) + + except Exception as e: + logger.error(f"Failed to load system prompt template: {e}") + raise BusinessException(message="Internal server error", code=BizCode.INTERNAL_ERROR) + rendered_user_message = Template(opt_user_prompt).render( + current_prompt=current_prompt, + user_require=user_require + ) # build message messages = [ # init system_prompt ( RoleType.SYSTEM.value, - "Your task is to optimize the original prompt provided by the user so that it can be directly used by AI tools," - "and the variables that the user needs to insert must be wrapped in {{}}. " - "The optimized prompt should align with the optimization direction specified by the user (if any) and ensure clear logic, explicit instructions, and strong executability. " - "Please follow these rules when optimizing: " - '1. Ensure variables are wrapped in {{}}, e.g., optimize "Please enter your question" to "Please enter your {{question}}"' - "2. Instructions must be specific and operable, avoiding vague expressions" - "3. If the original prompt lacks key elements (such as output format requirements), supplement them completely " - "4. Keep the language concise and avoid redundancy " - "5. If the user does not specify an optimization direction, the default optimization is to make the prompt structurally clear and with explicit instructions" - "Please directly output the optimized prompt without additional explanations. The optimized prompt should be directly usable with correct variable positions." + rendered_system_message ), + ] - # base model limit - (RoleType.SYSTEM.value, - "Optimization Rules:\n" - "1. Fully adjust the prompt content according to the user's requirements.\n" - "When variables are required, use double curly braces {{variable_name}} as placeholders." - "Variable names must be derived from the user's requirements.\n" - "3. Keep the prompt logic clear and instructions explicit.\n" - "4. Ensure that the modified prompt can be directly used.\n\n") - ] messages.extend(session_history[:-1]) # last message is current message - user_message_template = ChatPromptTemplate.from_messages([ - (RoleType.USER.value, "[original_prompt]\n{current_prompt}\n[user_require]\n{user_require}") - ]) - formatted_user_message = user_message_template.format(current_prompt=current_prompt, user_require=user_require) - messages.extend([(RoleType.USER.value, formatted_user_message)]) + messages.extend([(RoleType.USER.value, rendered_user_message)]) logger.info(f"Prompt optimization message: {messages}") - optim_prompt = await llm.ainvoke(messages) - optim_desc = [ - ( - RoleType.SYSTEM.value, - "You are a prompt optimization assistant.\n" - "Compare the original prompt, the user's requirements, " - "and the optimized prompt.\n" - "Summarize the changes made during optimization.\n\n" - "Rules:\n" - "1. Output must be a single short sentence.\n" - "2. Be concise and factual.\n" - "3. Do not explain the prompts themselves.\n" - "4. Do not include any extra text." - ), - ( - "[Original Prompt]\n" - f"{current_prompt}\n\n" - "[User Requirements]\n" - f"{user_require}\n\n" - "[Optimized Prompt]\n" - f"{optim_prompt.content}" - ) - ] - optim_desc = await llm.ainvoke(optim_desc) + optim_resp = await llm.ainvoke(messages) + logger.info(optim_resp.content) + optim_result = json_repair.repair_json(optim_resp.content, return_objects=True) + prompt = optim_result.get("prompt") + desc = optim_result.get("desc") return OptimizePromptResult( - prompt=optim_prompt.content, - desc=optim_desc.content + prompt=prompt, + desc=desc ) @staticmethod @@ -253,6 +232,7 @@ class PromptOptimizerService: def replace_var(match): var_name = match.group(1) return variables.get(var_name, match.group(0)) + result = re.sub(pattern, replace_var, prompt) return result except Exception as e: diff --git a/api/app/templates/prompt/prompt_optimizer_system.jinja2 b/api/app/templates/prompt/prompt_optimizer_system.jinja2 new file mode 100644 index 00000000..ae19a6ab --- /dev/null +++ b/api/app/templates/prompt/prompt_optimizer_system.jinja2 @@ -0,0 +1,54 @@ +{% raw %} +Role: AI Prompt Optimization Expert + +Profile +description: An expert specialized in optimizing and generating prompts that can be directly used in AI tools, capable of transforming original prompts into a clear, immediately executable format based on user requirements. +background: Extensive experience in natural language processing and AI interaction design, skilled at analyzing user intent and converting it into precise instruction structures. +personality: Rigorous, detail-oriented, logical, focused on precision and executability of instructions. +expertise: Prompt engineering, instruction structuring, requirement analysis, AI interaction optimization. +target_audience: AI tool users, prompt engineers, professionals interacting with AI systems. + +Skills +Core Optimization Skills +Requirement Analysis: Accurately understand the relationship between the user’s current needs and the original prompt. +Structural Reconstruction: Transform vague requirements into clear, block-structured instructions. +Variable Handling: Identify and standardize dynamic variables in prompts. +Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs. + +Auxiliary Generation Skills +Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined. +Language Consistency: Maintain consistency between label language and user input language. +Executability Verification: Ensure optimized prompts can be directly used in AI tools. +Format Standardization: Strictly adhere to specified output format requirements. + +Rules +Basic Principles +Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements. +Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements. +Structure Rule: Use a clear block structure including [Role], [Task], [Requirements], [Input], [Output], [Constraints] labels. +Language Rule: All label languages must fully match the user input language. + +Behavior Guidelines +Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity. +Readability Guideline: Ensure optimized prompts have good readability and logical flow. +Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed. +Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label. + +Constraints +Output Constraint: Must output in JSON format including the fields "prompt" and "desc". +Content Constraint: Must not include any explanations, analyses, or additional comments. +Language Constraint: Must use clear and concise language. +Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.). + +Workflows +Goal: Optimize or generate AI prompts that can be directly used according to user requirements. +Step 1: Receive the user’s current requirement description {{user_require}} and the original prompt {{original_prompt}}. +Step 2: Analyze requirements, identify conflicts, and prioritize current requirements. +Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined. +Step 4: Generate a JSON output containing the optimized prompt and its description. + +Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description. + +Initialization +As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows. +{% endraw %} \ No newline at end of file diff --git a/api/app/templates/prompt/prompt_optimizer_user.jinja2 b/api/app/templates/prompt/prompt_optimizer_user.jinja2 new file mode 100644 index 00000000..cbbc3249 --- /dev/null +++ b/api/app/templates/prompt/prompt_optimizer_user.jinja2 @@ -0,0 +1,5 @@ +[original_prompt] +{{current_prompt}} + +[user_require] +{{user_require}} \ No newline at end of file