feat(workflow): add token usage statistics for question classifier and parameter extraction
This commit is contained in:
@@ -23,6 +23,18 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||||
|
self.response_metadata = {}
|
||||||
|
|
||||||
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
|
if self.response_metadata:
|
||||||
|
usage = self.response_metadata.get('token_usage')
|
||||||
|
if usage:
|
||||||
|
return {
|
||||||
|
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||||
|
"completion_tokens": usage.get('completion_tokens', 0),
|
||||||
|
"total_tokens": usage.get('total_tokens', 0)
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_prompt():
|
def _get_prompt():
|
||||||
@@ -171,6 +183,7 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
])
|
])
|
||||||
|
|
||||||
model_resp = await llm.ainvoke(messages)
|
model_resp = await llm.ainvoke(messages)
|
||||||
|
self.response_metadata = model_resp.response_metadata
|
||||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||||
logger.info(f"node: {self.node_id} get params:{result}")
|
logger.info(f"node: {self.node_id} get params:{result}")
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,18 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||||
self.category_to_case_map = {}
|
self.category_to_case_map = {}
|
||||||
|
self.response_metadata = {}
|
||||||
|
|
||||||
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
|
if self.response_metadata:
|
||||||
|
usage = self.response_metadata.get('token_usage')
|
||||||
|
if usage:
|
||||||
|
return {
|
||||||
|
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||||
|
"completion_tokens": usage.get('completion_tokens', 0),
|
||||||
|
"total_tokens": usage.get('total_tokens', 0)
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
def _get_llm_instance(self) -> RedBearLLM:
|
def _get_llm_instance(self) -> RedBearLLM:
|
||||||
"""获取LLM实例"""
|
"""获取LLM实例"""
|
||||||
@@ -112,6 +124,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
|
|
||||||
response = await llm.ainvoke(messages)
|
response = await llm.ainvoke(messages)
|
||||||
result = response.content.strip()
|
result = response.content.strip()
|
||||||
|
self.response_metadata = response.response_metadata
|
||||||
|
|
||||||
if result in category_names:
|
if result in category_names:
|
||||||
category = result
|
category = result
|
||||||
|
|||||||
Reference in New Issue
Block a user