feat(workflow): add parameter extraction node

- Implemented ParameterExtractorNode to extract structured parameters from input text using LLM.
- Supports dynamic Jinja2 prompt rendering with field descriptions and types.
- Integrates with RedBearLLM and ModelConfigService for model retrieval.
- Handles JSON repair and raises clear BusinessException on parsing errors.
This commit is contained in:
mengyonghao
2025-12-26 14:39:17 +08:00
parent 7181d41e51
commit fc15a7793a
9 changed files with 264 additions and 6 deletions

View File

@@ -16,6 +16,7 @@ from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
__all__ = [
"BaseNode",
@@ -32,4 +33,5 @@ __all__ = [
"AssignerNode",
"HttpRequestNode",
"JinjaRenderNode",
"ParameterExtractorNode"
]

View File

@@ -26,6 +26,7 @@ class NodeType(StrEnum):
ASSIGNER = "assigner"
JINJARENDER = "jinja-render"
VAR_AGGREGATOR = "var-aggregator"
PARAMETER_EXTRACTOR = "parameter-extractor"
class ComparisonOperator(StrEnum):

View File

@@ -17,6 +17,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.jinja_render import JinjaRenderNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
@@ -35,7 +36,8 @@ WorkflowNode = Union[
HttpRequestNode,
KnowledgeRetrievalNode,
JinjaRenderNode,
VariableAggregatorNode
VariableAggregatorNode,
ParameterExtractorNode
]
@@ -57,7 +59,8 @@ class NodeFactory:
NodeType.ASSIGNER: AssignerNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.JINJARENDER: JinjaRenderNode,
NodeType.VAR_AGGREGATOR: VariableAggregatorNode
NodeType.VAR_AGGREGATOR: VariableAggregatorNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
}
@classmethod

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.parameter_extractor.node import ParameterExtractorNode
__all__ = ["ParameterExtractorNode", "ParameterExtractorNodeConfig"]

View File

@@ -0,0 +1,54 @@
import uuid
from pydantic import Field, BaseModel
from enum import StrEnum
from app.core.workflow.nodes.base_config import BaseNodeConfig
class ParamVariableType(StrEnum):
"""
Enum for variable types that can be extracted as parameters.
Each member represents a type that can be used in parameter extraction.
"""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_OBJECT = "array[object]"
class ParamsConfig(BaseModel):
    # Declaration of a single parameter to extract: its name, the value type
    # expected for it, and a free-text description. All three are required.

    # Key under which the parameter is reported.
    name: str = Field(..., description="Name of the parameter")

    # One of the ParamVariableType members.
    type: ParamVariableType = Field(..., description="Type of the parameter")

    # Human-readable explanation of what the parameter means.
    desc: str = Field(..., description="Description of the parameter")
class ParameterExtractorNodeConfig(BaseNodeConfig):
    # Configuration of a parameter-extractor node: which model to call, the
    # text to extract from, and the parameters to look for. All three fields
    # are required.

    # Identifier of the model configuration to use.
    model_id: uuid.UUID = Field(..., description="Unique identifier for the model")

    # Input text to run the extraction on.
    text: str = Field(..., description="The string to be extracted as a parameter")

    # Parameters to extract, one ParamsConfig entry each.
    params: list[ParamsConfig] = Field(..., description="List of parameters")

View File

@@ -0,0 +1,165 @@
import os
import json_repair
from typing import Any
from jinja2 import Template
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
class ParameterExtractorNode(BaseNode):
    """Workflow node that asks an LLM to extract structured parameters from text.

    The raw node configuration is validated into a
    ``ParameterExtractorNodeConfig`` (model id, input text, parameter list)
    when the node is constructed.
    """

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        """Initialise the base node and validate this node's configuration.

        Args:
            node_config: Raw configuration of this node.
            workflow_config: Raw configuration of the enclosing workflow.
        """
        super().__init__(node_config, workflow_config)
        # NOTE(review): ``self.config`` is presumably populated by
        # BaseNode.__init__ from ``node_config`` - confirm against BaseNode.
        self.typed_config = ParameterExtractorNodeConfig(**self.config)

    @staticmethod
    def _get_prompt() -> tuple[str, str]:
        """Load the system and user prompt templates stored beside this module.

        Both files are read from the ``prompt`` directory next to this file on
        every call; nothing is cached, so edits to the templates are picked up
        immediately.

        Returns:
            ``(system_prompt, user_prompt)`` as raw template text.

        Raises:
            OSError: If either template file cannot be opened.
        """
        prompt_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompt")

        with open(os.path.join(prompt_dir, "system_prompt.jinja2"), encoding='utf-8') as f:
            system_prompt = f.read()

        with open(os.path.join(prompt_dir, "user_prompt.jinja2"), encoding='utf-8') as f:
            user_prompt = f.read()

        return system_prompt, user_prompt

    def _get_llm_instance(self) -> RedBearLLM:
        """Build a ``RedBearLLM`` from the model configuration stored in the database.

        The model is looked up by ``typed_config.model_id`` and the first API
        key entry of that model supplies the credentials.

        Returns:
            A ``RedBearLLM`` configured with the model name, provider, API key
            and base URL of the first API key entry.

        Raises:
            BusinessException: If the model does not exist (``NOT_FOUND``) or
                has no API key configured (``INVALID_PARAMETER``).
        """
        model_id = self.typed_config.model_id

        with get_db_read() as db:
            config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)

            if not config:
                raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)

            # An empty collection is falsy, so no separate length check is needed.
            if not config.api_keys:
                raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)

            # Copy everything we need while the session is still open, so no
            # attribute is read from ``config`` after the context exits.
            api_config = config.api_keys[0]
            model_name = api_config.model_name
            provider = api_config.provider
            api_key = api_config.api_key
            api_base = api_config.api_base
            model_type = config.type

        return RedBearLLM(
            RedBearModelConfig(
                model_name=model_name,
                provider=provider,
                api_key=api_key,
                base_url=api_base,
            ),
            type=ModelType(model_type)
        )

    def _get_field_desc(self) -> dict[str, str]:
        """Map each configured parameter name to its description.

        Returns:
            ``{param.name: param.desc}`` for every entry in
            ``typed_config.params``.
        """
        return {param.name: param.desc for param in self.typed_config.params}

    def _get_field_type(self) -> dict[str, str]:
        """Map each configured parameter name to its declared type name.

        The type is converted with ``str()`` so the mapping holds plain
        strings such as ``"array[string]"``. Without this, stringifying the
        mapping for the prompt would embed enum reprs like
        ``<ParamVariableType.STRING: 'string'>`` instead of the type name.

        Returns:
            ``{param.name: str(param.type)}`` for every entry in
            ``typed_config.params``.
        """
        return {param.name: str(param.type) for param in self.typed_config.params}

    async def execute(self, state: WorkflowState) -> Any:
        """Run the extraction: prompt the LLM and return its repaired JSON output.

        Steps:
            1. Build the LLM instance from the configured model.
            2. Render the user prompt with the field descriptions, field types
               and the input text (itself rendered against ``state``).
            3. Send the system and user prompts to the LLM.
            4. Repair the reply with ``json_repair`` and verify that it is a
               JSON object.

        Args:
            state: Current workflow state, used to render the input text.

        Returns:
            ``{"output": <repaired JSON text>}``. The value is the repaired
            JSON document as a string, as produced by
            ``json_repair.repair_json``.

        Raises:
            BusinessException: If the model or its API key is missing, or if
                the LLM reply cannot be repaired into a JSON object.
        """
        llm = self._get_llm_instance()
        system_prompt, user_prompt = self._get_prompt()

        user_prompt_template = Template(user_prompt)
        rendered_user_prompt = user_prompt_template.render(
            field_descriptions=str(self._get_field_desc()),
            field_type=str(self._get_field_type()),
            text_input=self._render_template(self.typed_config.text, state),
        )

        messages = [
            ("system", system_prompt),
            ("user", rendered_user_prompt),
        ]

        model_resp = await llm.ainvoke(messages)
        result = json_repair.repair_json(model_resp.content)

        # The system prompt demands a single JSON object. json_repair is very
        # lenient and yields an empty or non-object result for unusable input,
        # so check explicitly rather than passing garbage downstream.
        # NOTE(review): no dedicated parse-error BizCode is visible from this
        # module; INVALID_PARAMETER is used as the closest one - confirm.
        if not isinstance(json_repair.loads(result), dict):
            raise BusinessException(
                "模型输出无法解析为有效的 JSON 对象",
                BizCode.INVALID_PARAMETER,
            )

        return {
            "output": result,
        }

View File

@@ -0,0 +1,10 @@
You are an information extraction engine.
Your task is to extract structured data from text and output a valid JSON object (json).
Rules:
- Output MUST be a single valid JSON object.
- No explanations, markdown, code blocks, or extra text.
- Never invent information.
- If a field is not found, use null.
- Follow the output structure exactly.

View File

@@ -0,0 +1,12 @@
Extract structured information from the following text.
Field Descriptions:
{{field_descriptions}}
Output Structure:
{{field_type}}
Input Text:
{{text_input}}
Output:

View File

@@ -23,6 +23,7 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
# Dependency to get a DB session
def get_db():
db = SessionLocal()
@@ -35,6 +36,7 @@ def get_db():
finally:
db.close()
@contextmanager
def get_db_context() -> Generator[Session, None, None]:
"""
@@ -54,12 +56,16 @@ def get_db_context() -> Generator[Session, None, None]:
db.rollback()
db.close()
@contextmanager
def get_db_read() -> Generator[Session, None, None]:
    """Yield a session for read-only work and always roll back on exit.

    The session comes from ``get_db_context()``. The rollback sits in a
    ``finally`` clause so it runs whether the caller's block finishes normally
    or raises; this keeps a read-only task from leaving the connection
    "idle in transaction".

    Yields:
        Session: The session obtained from ``get_db_context()``.
    """
    with get_db_context() as db:
        try:
            yield db
        finally:
            db.rollback()  # read-only work never needs a commit
def get_pool_status():
"""获取连接池状态(用于监控)"""
@@ -70,5 +76,6 @@ def get_pool_status():
"checked_out": pool.checkedout(),
"overflow": pool.overflow(),
"total": pool.size() + pool.overflow(),
"usage_percent": round(pool.checkedout() / (pool.size() + pool.overflow()) * 100, 2) if (pool.size() + pool.overflow()) > 0 else 0
}
"usage_percent": round(pool.checkedout() / (pool.size() + pool.overflow()) * 100, 2) if (
pool.size() + pool.overflow()) > 0 else 0
}