feat(workflow): add parameter extraction node
- Implemented ParameterExtractorNode to extract structured parameters from input text using LLM. - Supports dynamic Jinja2 prompt rendering with field descriptions and types. - Integrates with RedBearLLM and ModelConfigService for model retrieval. - Handles JSON repair and raises clear BusinessException on parsing errors.
This commit is contained in:
@@ -16,6 +16,7 @@ from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
@@ -32,4 +33,5 @@ __all__ = [
|
||||
"AssignerNode",
|
||||
"HttpRequestNode",
|
||||
"JinjaRenderNode",
|
||||
"ParameterExtractorNode"
|
||||
]
|
||||
|
||||
@@ -26,6 +26,7 @@ class NodeType(StrEnum):
|
||||
ASSIGNER = "assigner"
|
||||
JINJARENDER = "jinja-render"
|
||||
VAR_AGGREGATOR = "var-aggregator"
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
|
||||
@@ -17,6 +17,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
|
||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
@@ -35,7 +36,8 @@ WorkflowNode = Union[
|
||||
HttpRequestNode,
|
||||
KnowledgeRetrievalNode,
|
||||
JinjaRenderNode,
|
||||
VariableAggregatorNode
|
||||
VariableAggregatorNode,
|
||||
ParameterExtractorNode
|
||||
]
|
||||
|
||||
|
||||
@@ -57,7 +59,8 @@ class NodeFactory:
|
||||
NodeType.ASSIGNER: AssignerNode,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||
NodeType.JINJARENDER: JinjaRenderNode,
|
||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNode
|
||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.parameter_extractor.node import ParameterExtractorNode
|
||||
|
||||
__all__ = ["ParameterExtractorNode", "ParameterExtractorNodeConfig"]
|
||||
54
api/app/core/workflow/nodes/parameter_extractor/config.py
Normal file
54
api/app/core/workflow/nodes/parameter_extractor/config.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import uuid
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
from enum import StrEnum
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class ParamVariableType(StrEnum):
|
||||
"""
|
||||
Enum for variable types that can be extracted as parameters.
|
||||
Each member represents a type that can be used in parameter extraction.
|
||||
"""
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
|
||||
|
||||
class ParamsConfig(BaseModel):
    """Declaration of a single parameter to extract.

    All three fields are required (``...`` means "no default").
    """

    # Key under which the extracted value is expected to appear.
    name: str = Field(..., description="Name of the parameter")

    # Declared value type; must be one of the ParamVariableType members.
    type: ParamVariableType = Field(..., description="Type of the parameter")

    # Human-readable explanation of what the parameter means.
    desc: str = Field(..., description="Description of the parameter")
|
||||
|
||||
|
||||
class ParameterExtractorNodeConfig(BaseNodeConfig):
    """Configuration schema for the parameter-extractor workflow node.

    Every field is required (``...`` means "no default").
    """

    # Identifier of the model configuration to run the extraction with.
    model_id: uuid.UUID = Field(..., description="Unique identifier for the model")

    # Source text from which the parameters are extracted.
    text: str = Field(..., description="The string to be extracted as a parameter")

    # One entry per parameter that should be extracted.
    params: list[ParamsConfig] = Field(..., description="List of parameters")
|
||||
165
api/app/core/workflow/nodes/parameter_extractor/node.py
Normal file
165
api/app/core/workflow/nodes/parameter_extractor/node.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import os
|
||||
|
||||
import json_repair
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
|
||||
class ParameterExtractorNode(BaseNode):
    """Workflow node that extracts structured parameters from text with an LLM.

    The node renders a user prompt from the configured parameter names,
    descriptions and types, sends it together with a fixed system prompt to
    the configured model, and parses the reply into a JSON object.
    """

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        """Initialise the base node and validate the node configuration.

        Args:
            node_config: Raw node definition, forwarded to ``BaseNode``.
            workflow_config: Raw workflow definition, forwarded to ``BaseNode``.
        """
        super().__init__(node_config, workflow_config)
        # ``self.config`` is presumably populated by BaseNode.__init__ — TODO confirm.
        self.typed_config = ParameterExtractorNodeConfig(**self.config)

    @staticmethod
    def _get_prompt() -> tuple[str, str]:
        """Read the system and user prompt templates shipped next to this module.

        The files are read on every call, so edits to the templates take
        effect without restarting the process.

        Returns:
            The raw ``(system_prompt, user_prompt)`` template texts.

        Raises:
            OSError: If either template file cannot be opened.
        """
        prompt_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompt")

        with open(os.path.join(prompt_dir, "system_prompt.jinja2"), encoding="utf-8") as f:
            system_prompt = f.read()

        with open(os.path.join(prompt_dir, "user_prompt.jinja2"), encoding="utf-8") as f:
            user_prompt = f.read()

        return system_prompt, user_prompt

    def _get_llm_instance(self) -> RedBearLLM:
        """Build a ``RedBearLLM`` for the model referenced by the node config.

        The model configuration is looked up by ``model_id`` and the first
        configured API key entry supplies the credentials.

        Returns:
            A ``RedBearLLM`` built from the stored model configuration.

        Raises:
            BusinessException: If the model does not exist or has no API key.
        """
        model_id = self.typed_config.model_id

        with get_db_read() as db:
            config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)

            if not config:
                raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)

            # An empty or missing key list is rejected the same way.
            if not config.api_keys:
                raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)

            # Copy out plain values while the session is still open;
            # presumably ``api_keys`` is an ORM relationship — TODO confirm.
            api_config = config.api_keys[0]
            model_name = api_config.model_name
            provider = api_config.provider
            api_key = api_config.api_key
            api_base = api_config.api_base
            model_type = config.type

        return RedBearLLM(
            RedBearModelConfig(
                model_name=model_name,
                provider=provider,
                api_key=api_key,
                base_url=api_base,
            ),
            type=ModelType(model_type),
        )

    def _get_field_desc(self) -> dict[str, str]:
        """Map each configured parameter name to its description.

        Returns:
            ``{parameter name: description}`` for every configured parameter.
        """
        return {param.name: param.desc for param in self.typed_config.params}

    def _get_field_type(self) -> dict[str, str]:
        """Map each configured parameter name to its declared type.

        The plain string value of the enum member is used (e.g. ``"number"``)
        so that ``str()`` of the mapping shows ``'number'`` in the prompt
        rather than the enum repr ``<ParamVariableType.NUMBER: 'number'>``.

        Returns:
            ``{parameter name: type string}`` for every configured parameter.
        """
        return {param.name: param.type.value for param in self.typed_config.params}

    async def execute(self, state: WorkflowState) -> Any:
        """Run the extraction and return the parsed parameters.

        Steps:
            1. Build the LLM instance for the configured model.
            2. Render the user prompt with field descriptions, field types
               and the input text (itself rendered against ``state``).
            3. Send the system and user prompts to the LLM.
            4. Repair and parse the reply as JSON.

        Args:
            state: Current workflow state, used to render the input text.

        Returns:
            ``{"output": <dict of extracted parameters>}``.

        Raises:
            BusinessException: If the LLM reply cannot be parsed into a JSON
                object.
        """
        llm = self._get_llm_instance()
        system_prompt, user_prompt = self._get_prompt()

        user_prompt_template = Template(user_prompt)
        rendered_user_prompt = user_prompt_template.render(
            field_descriptions=str(self._get_field_desc()),
            field_type=str(self._get_field_type()),
            text_input=self._render_template(self.typed_config.text, state),
        )

        messages = [
            ("system", system_prompt),
            ("user", rendered_user_prompt),
        ]

        model_resp = await llm.ainvoke(messages)

        # return_objects=True yields the decoded Python object; the default
        # call returns the repaired JSON as a string, which would contradict
        # the documented dict result.
        result = json_repair.repair_json(model_resp.content, return_objects=True)

        # The system prompt demands a single JSON object, so anything else
        # (empty string, list, scalar) is treated as a parse failure instead
        # of being passed on silently.
        if not isinstance(result, dict):
            # NOTE(review): INVALID_PARAMETER is reused because it is the only
            # suitable code visible here; swap in a more specific one if it exists.
            raise BusinessException(
                "模型输出无法解析为 JSON 对象",
                BizCode.INVALID_PARAMETER,
            )

        return {
            "output": result,
        }
|
||||
@@ -0,0 +1,10 @@
|
||||
You are an information extraction engine.
|
||||
|
||||
Your task is to extract structured data from text and output a valid JSON object (json).
|
||||
|
||||
Rules:
|
||||
- Output MUST be a single valid JSON object.
|
||||
- No explanations, markdown, code blocks, or extra text.
|
||||
- Never invent information.
|
||||
- If a field is not found, use null.
|
||||
- Follow the output structure exactly.
|
||||
@@ -0,0 +1,12 @@
|
||||
Extract structured information from the following text.
|
||||
|
||||
Field Descriptions:
|
||||
{{field_descriptions}}
|
||||
|
||||
Output Structure:
|
||||
{{field_type}}
|
||||
|
||||
Input Text:
|
||||
{{text_input}}
|
||||
|
||||
Output:
|
||||
@@ -23,6 +23,7 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# Dependency to get a DB session
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
@@ -35,6 +36,7 @@ def get_db():
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_context() -> Generator[Session, None, None]:
|
||||
"""
|
||||
@@ -54,12 +56,16 @@ def get_db_context() -> Generator[Session, None, None]:
|
||||
db.rollback()
|
||||
db.close()
|
||||
|
||||
|
||||
@contextmanager
def get_db_read() -> Generator[Session, None, None]:
    """Yield a session for read-only work and always roll back on exit.

    Intended for read-only callers: the transaction is rolled back when the
    context exits, so that no "idle in transaction" connection is left behind.

    Yields:
        The session obtained from ``get_db_context()``.
    """
    with get_db_context() as db:
        # A @contextmanager generator must yield exactly once; try/finally
        # ensures the rollback also runs when the caller's block raises.
        try:
            yield db
        finally:
            db.rollback()  # read-only work never needs a commit
|
||||
|
||||
|
||||
def get_pool_status():
|
||||
"""获取连接池状态(用于监控)"""
|
||||
@@ -70,5 +76,6 @@ def get_pool_status():
|
||||
"checked_out": pool.checkedout(),
|
||||
"overflow": pool.overflow(),
|
||||
"total": pool.size() + pool.overflow(),
|
||||
"usage_percent": round(pool.checkedout() / (pool.size() + pool.overflow()) * 100, 2) if (pool.size() + pool.overflow()) > 0 else 0
|
||||
}
|
||||
"usage_percent": round(pool.checkedout() / (pool.size() + pool.overflow()) * 100, 2) if (
|
||||
pool.size() + pool.overflow()) > 0 else 0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user