diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index 33fe040c..174fa877 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -16,6 +16,7 @@ from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode +from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode __all__ = [ "BaseNode", @@ -32,4 +33,5 @@ __all__ = [ "AssignerNode", "HttpRequestNode", "JinjaRenderNode", + "ParameterExtractorNode" ] diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 3ca59695..b4cc0634 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -26,6 +26,7 @@ class NodeType(StrEnum): ASSIGNER = "assigner" JINJARENDER = "jinja-render" VAR_AGGREGATOR = "var-aggregator" + PARAMETER_EXTRACTOR = "parameter-extractor" class ComparisonOperator(StrEnum): diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 9a5b5093..98c1468f 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -17,6 +17,7 @@ from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.jinja_render import JinjaRenderNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.llm import LLMNode +from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode @@ -35,7 +36,8 @@ WorkflowNode = Union[ HttpRequestNode, KnowledgeRetrievalNode, JinjaRenderNode, - VariableAggregatorNode + VariableAggregatorNode, + ParameterExtractorNode ] @@ -57,7 +59,8 @@ class NodeFactory: NodeType.ASSIGNER: AssignerNode, NodeType.HTTP_REQUEST: HttpRequestNode, NodeType.JINJARENDER: JinjaRenderNode, - NodeType.VAR_AGGREGATOR: VariableAggregatorNode + NodeType.VAR_AGGREGATOR: VariableAggregatorNode, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, } @classmethod diff --git a/api/app/core/workflow/nodes/parameter_extractor/__init__.py b/api/app/core/workflow/nodes/parameter_extractor/__init__.py new file mode 100644 index 00000000..342ea27c --- /dev/null +++ b/api/app/core/workflow/nodes/parameter_extractor/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig +from app.core.workflow.nodes.parameter_extractor.node import ParameterExtractorNode + +__all__ = ["ParameterExtractorNode", "ParameterExtractorNodeConfig"] diff --git a/api/app/core/workflow/nodes/parameter_extractor/config.py b/api/app/core/workflow/nodes/parameter_extractor/config.py new file mode 100644 index 00000000..30c0e1ef --- /dev/null +++ b/api/app/core/workflow/nodes/parameter_extractor/config.py @@ -0,0 +1,54 @@ +import uuid + +from pydantic import Field, BaseModel +from enum import StrEnum + +from app.core.workflow.nodes.base_config import BaseNodeConfig + + +class ParamVariableType(StrEnum): + """ + Enum for variable types that can be extracted as parameters. + Each member represents a type that can be used in parameter extraction. + """ + STRING = "string" + NUMBER = "number" + BOOLEAN = "boolean" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_BOOLEAN = "array[boolean]" + ARRAY_OBJECT = "array[object]" + + +class ParamsConfig(BaseModel): + name: str = Field( + ..., + description="Name of the parameter" + ) + + type: ParamVariableType = Field( + ..., + description="Type of the parameter" + ) + + desc: str = Field( + ..., + description="Description of the parameter" + ) + + +class ParameterExtractorNodeConfig(BaseNodeConfig): + model_id: uuid.UUID = Field( + ..., + description="Unique identifier for the model" + ) + + text: str = Field( + ..., + description="The string to be extracted as a parameter" + ) + + params: list[ParamsConfig] = Field( + ..., + description="List of parameters" + ) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py new file mode 100644 index 00000000..f991e7dc --- /dev/null +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -0,0 +1,165 @@ +import os + +import json_repair +from typing import Any + +from jinja2 import Template + +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.models import RedBearLLM, RedBearModelConfig +from app.core.workflow.nodes import WorkflowState +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig +from app.db import get_db_read +from app.models import ModelType +from app.services.model_service import ModelConfigService + + +class ParameterExtractorNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = ParameterExtractorNodeConfig(**self.config) + + @staticmethod + def _get_prompt(): + """ + Load system and user prompt templates from local prompt files. + + Notes: + - Templates are expected to be Jinja2 files. + - Reading from disk each time ensures the latest template is used (could be cached if performance-critical). + - Both templates must exist, otherwise an exception will be raised. + + Returns: + Tuple[str, str]: system_prompt, user_prompt + """ + current_dir = os.path.dirname(os.path.abspath(__file__)) + with open( + os.path.join(current_dir, "prompt", "system_prompt.jinja2"), + encoding='utf-8' + ) as f: + system_prompt = f.read() + with open(os.path.join( + current_dir, "prompt", "user_prompt.jinja2"), + encoding='utf-8' + ) as f: + user_prompt = f.read() + return system_prompt, user_prompt + + def _get_llm_instance(self) -> RedBearLLM: + """ + Retrieve a configured LLM instance based on the model ID from database. + + Responsibilities: + - Validate that the model exists and has at least one API key configured. + - Construct RedBearLLM instance with proper credentials and model type. + - Raise clear BusinessException if configuration is invalid. + + Returns: + RedBearLLM: Configured LLM instance ready to be invoked. + + Raises: + BusinessException: If the model is missing or lacks valid API key. + """ + model_id = self.typed_config.model_id + + with get_db_read() as db: + config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) + + if not config: + raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) + + if not config.api_keys or len(config.api_keys) == 0: + raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) + + api_config = config.api_keys[0] + model_name = api_config.model_name + provider = api_config.provider + api_key = api_config.api_key + api_base = api_config.api_base + model_type = config.type + + llm = RedBearLLM( + RedBearModelConfig( + model_name=model_name, + provider=provider, + api_key=api_key, + base_url=api_base, + ), + type=ModelType(model_type) + ) + return llm + + def _get_field_desc(self) -> dict[str, str]: + """ + Build a dictionary mapping each parameter name to its description. + Useful for dynamically generating prompts for LLM. + + Returns: + dict[str, str]: Mapping of parameter names to descriptions. + """ + field_desc = {} + for param in self.typed_config.params: + field_desc[param.name] = param.desc + return field_desc + + def _get_field_type(self) -> dict[str, str]: + """ + Build a dictionary mapping each parameter name to its description. + Useful for dynamically generating prompts for LLM. + + Returns: + dict[str, str]: Mapping of parameter names to descriptions. + """ + field_type = {} + for param in self.typed_config.params: + field_type[param.name] = param.type + return field_type + + async def execute(self, state: WorkflowState) -> Any: + """ + Main execution function for this node. + + Workflow: + 1. Retrieve LLM instance with valid credentials. + 2. Render user prompt template with field descriptions, types, and input text. + 3. Send system and user prompts to LLM asynchronously. + 4. Repair LLM JSON output safely. + 5. Return output dictionary. + + Notes: + - JSON repair is used to handle minor formatting errors in LLM output. + - Exceptions are raised explicitly if parsing fails, to prevent silent workflow failures. + - Rendering uses self._render_template for dynamic substitution from workflow state. + + Args: + state (WorkflowState): Current state of the workflow, used for template rendering. + + Returns: + dict[str, Any]: Dictionary containing extracted parameters under the "output" key. + + Raises: + BusinessException: If LLM output cannot be parsed as valid JSON. + """ + llm = self._get_llm_instance() + system_prompt, user_prompt = self._get_prompt() + + user_prompt_teplate = Template(user_prompt) + rendered_user_prompt = user_prompt_teplate.render( + field_descriptions=str(self._get_field_desc()), + field_type=str(self._get_field_type()), + text_input=self._render_template(self.typed_config.text, state) + ) + + messages = [ + ("system", system_prompt), + ("user", rendered_user_prompt), + ] + + model_resp = await llm.ainvoke(messages) + result = json_repair.repair_json(model_resp.content) + + return { + "output": result, + } diff --git a/api/app/core/workflow/nodes/parameter_extractor/prompt/system_prompt.jinja2 b/api/app/core/workflow/nodes/parameter_extractor/prompt/system_prompt.jinja2 new file mode 100644 index 00000000..28d0e1f2 --- /dev/null +++ b/api/app/core/workflow/nodes/parameter_extractor/prompt/system_prompt.jinja2 @@ -0,0 +1,10 @@ +You are an information extraction engine. + +Your task is to extract structured data from text and output a valid JSON object (json). + +Rules: +- Output MUST be a single valid JSON object. +- No explanations, markdown, code blocks, or extra text. +- Never invent information. +- If a field is not found, use null. +- Follow the output structure exactly. diff --git a/api/app/core/workflow/nodes/parameter_extractor/prompt/user_prompt.jinja2 b/api/app/core/workflow/nodes/parameter_extractor/prompt/user_prompt.jinja2 new file mode 100644 index 00000000..82f3940b --- /dev/null +++ b/api/app/core/workflow/nodes/parameter_extractor/prompt/user_prompt.jinja2 @@ -0,0 +1,12 @@ +Extract structured information from the following text. + +Field Descriptions: +{{field_descriptions}} + +Output Structure: +{{field_type}} + +Input Text: +{{text_input}} + +Output: \ No newline at end of file diff --git a/api/app/db.py b/api/app/db.py index 2513dc78..cdaa6dbd 100644 --- a/api/app/db.py +++ b/api/app/db.py @@ -23,6 +23,7 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() + # Dependency to get a DB session def get_db(): db = SessionLocal() @@ -35,6 +36,7 @@ def get_db(): finally: db.close() + @contextmanager def get_db_context() -> Generator[Session, None, None]: """ @@ -54,12 +56,16 @@ def get_db_context() -> Generator[Session, None, None]: db.rollback() db.close() + @contextmanager def get_db_read() -> Generator[Session, None, None]: """只读场景专用,出上下文自动 rollback,绝不留下 idle in transaction""" with get_db_context() as db: - yield db - db.rollback() # 只读任务无需 commit + try: + yield db + finally: + db.rollback() # 只读任务无需 commit + def get_pool_status(): """获取连接池状态(用于监控)""" @@ -70,5 +76,6 @@ def get_pool_status(): "checked_out": pool.checkedout(), "overflow": pool.overflow(), "total": pool.size() + pool.overflow(), - "usage_percent": round(pool.checkedout() / (pool.size() + pool.overflow()) * 100, 2) if (pool.size() + pool.overflow()) > 0 else 0 - } \ No newline at end of file + "usage_percent": round(pool.checkedout() / (pool.size() + pool.overflow()) * 100, 2) if ( + pool.size() + pool.overflow()) > 0 else 0 + }