feat(workflow): add parameter extraction node
- Implemented ParameterExtractorNode to extract structured parameters from input text using LLM. - Supports dynamic Jinja2 prompt rendering with field descriptions and types. - Integrates with RedBearLLM and ModelConfigService for model retrieval. - Handles JSON repair and raises clear BusinessException on parsing errors.
This commit is contained in:
@@ -16,6 +16,7 @@ from app.core.workflow.nodes.llm import LLMNode
|
|||||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||||
from app.core.workflow.nodes.start import StartNode
|
from app.core.workflow.nodes.start import StartNode
|
||||||
from app.core.workflow.nodes.transform import TransformNode
|
from app.core.workflow.nodes.transform import TransformNode
|
||||||
|
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseNode",
|
"BaseNode",
|
||||||
@@ -32,4 +33,5 @@ __all__ = [
|
|||||||
"AssignerNode",
|
"AssignerNode",
|
||||||
"HttpRequestNode",
|
"HttpRequestNode",
|
||||||
"JinjaRenderNode",
|
"JinjaRenderNode",
|
||||||
|
"ParameterExtractorNode"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class NodeType(StrEnum):
|
|||||||
ASSIGNER = "assigner"
|
ASSIGNER = "assigner"
|
||||||
JINJARENDER = "jinja-render"
|
JINJARENDER = "jinja-render"
|
||||||
VAR_AGGREGATOR = "var-aggregator"
|
VAR_AGGREGATOR = "var-aggregator"
|
||||||
|
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||||
|
|
||||||
|
|
||||||
class ComparisonOperator(StrEnum):
|
class ComparisonOperator(StrEnum):
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
|
|||||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNode
|
from app.core.workflow.nodes.jinja_render import JinjaRenderNode
|
||||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||||
from app.core.workflow.nodes.llm import LLMNode
|
from app.core.workflow.nodes.llm import LLMNode
|
||||||
|
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||||
from app.core.workflow.nodes.start import StartNode
|
from app.core.workflow.nodes.start import StartNode
|
||||||
from app.core.workflow.nodes.transform import TransformNode
|
from app.core.workflow.nodes.transform import TransformNode
|
||||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||||
@@ -35,7 +36,8 @@ WorkflowNode = Union[
|
|||||||
HttpRequestNode,
|
HttpRequestNode,
|
||||||
KnowledgeRetrievalNode,
|
KnowledgeRetrievalNode,
|
||||||
JinjaRenderNode,
|
JinjaRenderNode,
|
||||||
VariableAggregatorNode
|
VariableAggregatorNode,
|
||||||
|
ParameterExtractorNode
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -57,7 +59,8 @@ class NodeFactory:
|
|||||||
NodeType.ASSIGNER: AssignerNode,
|
NodeType.ASSIGNER: AssignerNode,
|
||||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||||
NodeType.JINJARENDER: JinjaRenderNode,
|
NodeType.JINJARENDER: JinjaRenderNode,
|
||||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNode
|
NodeType.VAR_AGGREGATOR: VariableAggregatorNode,
|
||||||
|
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||||
|
from app.core.workflow.nodes.parameter_extractor.node import ParameterExtractorNode
|
||||||
|
|
||||||
|
__all__ = ["ParameterExtractorNode", "ParameterExtractorNodeConfig"]
|
||||||
54
api/app/core/workflow/nodes/parameter_extractor/config.py
Normal file
54
api/app/core/workflow/nodes/parameter_extractor/config.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
|
from pydantic import Field, BaseModel
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ParamVariableType(StrEnum):
|
||||||
|
"""
|
||||||
|
Enum for variable types that can be extracted as parameters.
|
||||||
|
Each member represents a type that can be used in parameter extraction.
|
||||||
|
"""
|
||||||
|
STRING = "string"
|
||||||
|
NUMBER = "number"
|
||||||
|
BOOLEAN = "boolean"
|
||||||
|
ARRAY_STRING = "array[string]"
|
||||||
|
ARRAY_NUMBER = "array[number]"
|
||||||
|
ARRAY_BOOLEAN = "array[boolean]"
|
||||||
|
ARRAY_OBJECT = "array[object]"
|
||||||
|
|
||||||
|
|
||||||
|
class ParamsConfig(BaseModel):
|
||||||
|
name: str = Field(
|
||||||
|
...,
|
||||||
|
description="Name of the parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
type: ParamVariableType = Field(
|
||||||
|
...,
|
||||||
|
description="Type of the parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
desc: str = Field(
|
||||||
|
...,
|
||||||
|
description="Description of the parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterExtractorNodeConfig(BaseNodeConfig):
|
||||||
|
model_id: uuid.UUID = Field(
|
||||||
|
...,
|
||||||
|
description="Unique identifier for the model"
|
||||||
|
)
|
||||||
|
|
||||||
|
text: str = Field(
|
||||||
|
...,
|
||||||
|
description="The string to be extracted as a parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
params: list[ParamsConfig] = Field(
|
||||||
|
...,
|
||||||
|
description="List of parameters"
|
||||||
|
)
|
||||||
165
api/app/core/workflow/nodes/parameter_extractor/node.py
Normal file
165
api/app/core/workflow/nodes/parameter_extractor/node.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import json_repair
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from jinja2 import Template
|
||||||
|
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.core.workflow.nodes import WorkflowState
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
|
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||||
|
from app.db import get_db_read
|
||||||
|
from app.models import ModelType
|
||||||
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterExtractorNode(BaseNode):
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = ParameterExtractorNodeConfig(**self.config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_prompt():
|
||||||
|
"""
|
||||||
|
Load system and user prompt templates from local prompt files.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Templates are expected to be Jinja2 files.
|
||||||
|
- Reading from disk each time ensures the latest template is used (could be cached if performance-critical).
|
||||||
|
- Both templates must exist, otherwise an exception will be raised.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, str]: system_prompt, user_prompt
|
||||||
|
"""
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
with open(
|
||||||
|
os.path.join(current_dir, "prompt", "system_prompt.jinja2"),
|
||||||
|
encoding='utf-8'
|
||||||
|
) as f:
|
||||||
|
system_prompt = f.read()
|
||||||
|
with open(os.path.join(
|
||||||
|
current_dir, "prompt", "user_prompt.jinja2"),
|
||||||
|
encoding='utf-8'
|
||||||
|
) as f:
|
||||||
|
user_prompt = f.read()
|
||||||
|
return system_prompt, user_prompt
|
||||||
|
|
||||||
|
def _get_llm_instance(self) -> RedBearLLM:
|
||||||
|
"""
|
||||||
|
Retrieve a configured LLM instance based on the model ID from database.
|
||||||
|
|
||||||
|
Responsibilities:
|
||||||
|
- Validate that the model exists and has at least one API key configured.
|
||||||
|
- Construct RedBearLLM instance with proper credentials and model type.
|
||||||
|
- Raise clear BusinessException if configuration is invalid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedBearLLM: Configured LLM instance ready to be invoked.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: If the model is missing or lacks valid API key.
|
||||||
|
"""
|
||||||
|
model_id = self.typed_config.model_id
|
||||||
|
|
||||||
|
with get_db_read() as db:
|
||||||
|
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
|
api_config = config.api_keys[0]
|
||||||
|
model_name = api_config.model_name
|
||||||
|
provider = api_config.provider
|
||||||
|
api_key = api_config.api_key
|
||||||
|
api_base = api_config.api_base
|
||||||
|
model_type = config.type
|
||||||
|
|
||||||
|
llm = RedBearLLM(
|
||||||
|
RedBearModelConfig(
|
||||||
|
model_name=model_name,
|
||||||
|
provider=provider,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=api_base,
|
||||||
|
),
|
||||||
|
type=ModelType(model_type)
|
||||||
|
)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
def _get_field_desc(self) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Build a dictionary mapping each parameter name to its description.
|
||||||
|
Useful for dynamically generating prompts for LLM.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, str]: Mapping of parameter names to descriptions.
|
||||||
|
"""
|
||||||
|
field_desc = {}
|
||||||
|
for param in self.typed_config.params:
|
||||||
|
field_desc[param.name] = param.desc
|
||||||
|
return field_desc
|
||||||
|
|
||||||
|
def _get_field_type(self) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Build a dictionary mapping each parameter name to its description.
|
||||||
|
Useful for dynamically generating prompts for LLM.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, str]: Mapping of parameter names to descriptions.
|
||||||
|
"""
|
||||||
|
field_type = {}
|
||||||
|
for param in self.typed_config.params:
|
||||||
|
field_type[param.name] = param.type
|
||||||
|
return field_type
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
|
"""
|
||||||
|
Main execution function for this node.
|
||||||
|
|
||||||
|
Workflow:
|
||||||
|
1. Retrieve LLM instance with valid credentials.
|
||||||
|
2. Render user prompt template with field descriptions, types, and input text.
|
||||||
|
3. Send system and user prompts to LLM asynchronously.
|
||||||
|
4. Repair LLM JSON output safely.
|
||||||
|
5. Return output dictionary.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- JSON repair is used to handle minor formatting errors in LLM output.
|
||||||
|
- Exceptions are raised explicitly if parsing fails, to prevent silent workflow failures.
|
||||||
|
- Rendering uses self._render_template for dynamic substitution from workflow state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (WorkflowState): Current state of the workflow, used for template rendering.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Any]: Dictionary containing extracted parameters under the "output" key.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: If LLM output cannot be parsed as valid JSON.
|
||||||
|
"""
|
||||||
|
llm = self._get_llm_instance()
|
||||||
|
system_prompt, user_prompt = self._get_prompt()
|
||||||
|
|
||||||
|
user_prompt_teplate = Template(user_prompt)
|
||||||
|
rendered_user_prompt = user_prompt_teplate.render(
|
||||||
|
field_descriptions=str(self._get_field_desc()),
|
||||||
|
field_type=str(self._get_field_type()),
|
||||||
|
text_input=self._render_template(self.typed_config.text, state)
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
("system", system_prompt),
|
||||||
|
("user", rendered_user_prompt),
|
||||||
|
]
|
||||||
|
|
||||||
|
model_resp = await llm.ainvoke(messages)
|
||||||
|
result = json_repair.repair_json(model_resp.content)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"output": result,
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
You are an information extraction engine.
|
||||||
|
|
||||||
|
Your task is to extract structured data from text and output a valid JSON object (json).
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- Output MUST be a single valid JSON object.
|
||||||
|
- No explanations, markdown, code blocks, or extra text.
|
||||||
|
- Never invent information.
|
||||||
|
- If a field is not found, use null.
|
||||||
|
- Follow the output structure exactly.
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
Extract structured information from the following text.
|
||||||
|
|
||||||
|
Field Descriptions:
|
||||||
|
{{field_descriptions}}
|
||||||
|
|
||||||
|
Output Structure:
|
||||||
|
{{field_type}}
|
||||||
|
|
||||||
|
Input Text:
|
||||||
|
{{text_input}}
|
||||||
|
|
||||||
|
Output:
|
||||||
@@ -23,6 +23,7 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
# Dependency to get a DB session
|
# Dependency to get a DB session
|
||||||
def get_db():
|
def get_db():
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
@@ -35,6 +36,7 @@ def get_db():
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_db_context() -> Generator[Session, None, None]:
|
def get_db_context() -> Generator[Session, None, None]:
|
||||||
"""
|
"""
|
||||||
@@ -54,12 +56,16 @@ def get_db_context() -> Generator[Session, None, None]:
|
|||||||
db.rollback()
|
db.rollback()
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_db_read() -> Generator[Session, None, None]:
|
def get_db_read() -> Generator[Session, None, None]:
|
||||||
"""只读场景专用,出上下文自动 rollback,绝不留下 idle in transaction"""
|
"""只读场景专用,出上下文自动 rollback,绝不留下 idle in transaction"""
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
yield db
|
try:
|
||||||
db.rollback() # 只读任务无需 commit
|
yield db
|
||||||
|
finally:
|
||||||
|
db.rollback() # 只读任务无需 commit
|
||||||
|
|
||||||
|
|
||||||
def get_pool_status():
|
def get_pool_status():
|
||||||
"""获取连接池状态(用于监控)"""
|
"""获取连接池状态(用于监控)"""
|
||||||
@@ -70,5 +76,6 @@ def get_pool_status():
|
|||||||
"checked_out": pool.checkedout(),
|
"checked_out": pool.checkedout(),
|
||||||
"overflow": pool.overflow(),
|
"overflow": pool.overflow(),
|
||||||
"total": pool.size() + pool.overflow(),
|
"total": pool.size() + pool.overflow(),
|
||||||
"usage_percent": round(pool.checkedout() / (pool.size() + pool.overflow()) * 100, 2) if (pool.size() + pool.overflow()) > 0 else 0
|
"usage_percent": round(pool.checkedout() / (pool.size() + pool.overflow()) * 100, 2) if (
|
||||||
|
pool.size() + pool.overflow()) > 0 else 0
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user