feat(workflow): add token usage statistics for question classifier and parameter extraction

This commit is contained in:
Eternity
2026-02-04 18:08:02 +08:00
parent 48aca996ff
commit 1c8a83140b
2 changed files with 26 additions and 0 deletions

View File

@@ -23,6 +23,18 @@ class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
if self.response_metadata:
usage = self.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
@staticmethod
def _get_prompt():
@@ -171,6 +183,7 @@ class ParameterExtractorNode(BaseNode):
])
model_resp = await llm.ainvoke(messages)
self.response_metadata = model_resp.response_metadata
result = json_repair.repair_json(model_resp.content, return_objects=True)
logger.info(f"node: {self.node_id} get params:{result}")

View File

@@ -23,6 +23,18 @@ class QuestionClassifierNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
self.response_metadata = {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
if self.response_metadata:
usage = self.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例"""
@@ -112,6 +124,7 @@ class QuestionClassifierNode(BaseNode):
response = await llm.ainvoke(messages)
result = response.content.strip()
self.response_metadata = response.response_metadata
if result in category_names:
category = result