fix(app):

1. Token consumption of the omni model;
2. Token consumption of the cluster includes sub-agents
This commit is contained in:
Timebomb2018
2026-03-30 18:37:09 +08:00
parent ed90405439
commit 876c39b1b0
6 changed files with 92 additions and 21 deletions

View File

@@ -254,6 +254,33 @@ class LangChainAgent:
return messages
@staticmethod
def _extract_tokens_from_message(msg) -> int:
"""从 AIMessage 或类似对象中提取 total_tokens兼容多种 provider 格式
支持的格式:
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
- response_metadata.usage.total_tokens (部分 provider)
- usage_metadata.total_tokens (LangChain 新版)
"""
total = 0
# 1. response_metadata
response_meta = getattr(msg, "response_metadata", None)
if response_meta and isinstance(response_meta, dict):
# 尝试 token_usage 路径
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
if isinstance(token_usage, dict):
total = token_usage.get("total_tokens", 0)
# 2. usage_metadataLangChain 新版 AIMessage 属性)
if not total:
usage_meta = getattr(msg, "usage_metadata", None)
if usage_meta:
if isinstance(usage_meta, dict):
total = usage_meta.get("total_tokens", 0)
else:
total = getattr(usage_meta, "total_tokens", 0)
return total or 0
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
构建多模态消息内容
@@ -412,8 +439,7 @@ class LangChainAgent:
else:
content = str(msg.content)
logger.debug(f"转换为字符串: {content[:100]}...")
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
total_tokens = self._extract_tokens_from_message(msg)
break
logger.info(f"最终提取的内容长度: {len(content)}")
@@ -458,7 +484,7 @@ class LangChainAgent:
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[str | int, None]:
"""执行流式对话
Args:
@@ -594,15 +620,13 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
# 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get(
"total_tokens",
0
) if response_meta else 0
yield total_tokens
stream_total_tokens = self._extract_tokens_from_message(msg)
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
yield stream_total_tokens
break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,

View File

@@ -58,7 +58,7 @@ class RedBearModelFactory:
write=60.0,
pool=10.0,
)
return {
params = {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
@@ -66,6 +66,10 @@ class RedBearModelFactory:
"max_retries": config.max_retries,
**config.extra_params
}
# 流式模式下启用 stream_usage 以获取 token 统计
if config.extra_params.get("streaming"):
params["stream_usage"] = True
return params
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
# 使用 httpx.Timeout 对象来设置详细的超时配置
@@ -78,7 +82,7 @@ class RedBearModelFactory:
write=60.0, # 写入超时60秒
pool=10.0, # 连接池超时10秒
)
return {
params = {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
@@ -86,6 +90,10 @@ class RedBearModelFactory:
"max_retries": config.max_retries,
**config.extra_params
}
# 流式模式下启用 stream_usage 以获取 token 统计
if config.extra_params.get("streaming"):
params["stream_usage"] = True
return params
elif provider == ModelProvider.DASHSCOPE:
# DashScope (通义千问) 使用自己的参数格式
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数

View File

@@ -99,7 +99,7 @@ class SimpleMCPClient:
# 建立 SSE 连接
response = await self._session.get(self.server_url)
if response.status != 200:
if not (200 <= response.status < 300):
error_text = await response.text()
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
@@ -190,7 +190,7 @@ class SimpleMCPClient:
try:
async with self._session.post(self._endpoint_url, json=request) as response:
if response.status != 200:
if not (200 <= response.status < 300):
error_text = await response.text()
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
@@ -205,7 +205,7 @@ class SimpleMCPClient:
raise MCPConnectionError("endpoint URL 未初始化")
async with self._session.post(self._endpoint_url, json=notification) as response:
if response.status != 200:
if not (200 <= response.status < 300):
logger.warning(f"通知发送失败: {response.status}")
async def _initialize_modelscope_session(self):
@@ -223,7 +223,7 @@ class SimpleMCPClient:
try:
async with self._session.post(self.server_url, json=init_request) as response:
if response.status != 200:
if not (200 <= response.status < 300):
error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")

View File

@@ -631,13 +631,13 @@ class AppChatService:
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
if "sub_usage" in event:
# 拦截 sub_usage 事件,累加 token
if "event: sub_usage" in event:
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "total_tokens" in data:
total_tokens += data["total_tokens"]
total_tokens += data.get("total_tokens", 0)
except:
pass
else:

View File

@@ -403,6 +403,17 @@ class MasterAgentRouter:
response = await llm.ainvoke(prompt)
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
# 提取 token 消耗
self._last_routing_tokens = 0
if hasattr(response, 'usage_metadata') and response.usage_metadata:
um = response.usage_metadata
self._last_routing_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
elif hasattr(response, 'response_metadata') and response.response_metadata:
token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {})
if isinstance(token_usage, dict):
self._last_routing_tokens = token_usage.get("total_tokens", 0)
logger.info(f"Master Agent 路由 token 消耗: {self._last_routing_tokens}")
# 提取响应内容
if hasattr(response, 'content'):
return response.content

View File

@@ -287,6 +287,11 @@ class MultiAgentOrchestrator:
sub_conversation_id = None
total_tokens = 0
# 累加 Master Agent 路由决策消耗的 token
total_tokens += task_analysis.get("routing_tokens", 0)
# 累加 Master Agent 整合消耗的 token
total_tokens += getattr(self, '_last_merge_tokens', 0)
if isinstance(results, dict):
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
# 提取 token 信息
@@ -358,12 +363,16 @@ class MultiAgentOrchestrator:
variables=variables
)
# 获取路由决策消耗的 token
routing_tokens = getattr(self.router, '_last_routing_tokens', 0)
logger.info(
"Master Agent 分析完成",
extra={
"selected_agent": routing_decision.get("selected_agent_id"),
"confidence": routing_decision.get("confidence"),
"strategy": routing_decision.get("strategy")
"strategy": routing_decision.get("strategy"),
"routing_tokens": routing_tokens
}
)
@@ -372,7 +381,8 @@ class MultiAgentOrchestrator:
"variables": variables or {},
"sub_agents": self.config.sub_agents,
"initial_context": variables or {},
"routing_decision": routing_decision
"routing_decision": routing_decision,
"routing_tokens": routing_tokens
}
async def _execute_sequential(
@@ -1032,6 +1042,11 @@ class MultiAgentOrchestrator:
# 5. 流式执行子 Agent
sub_conversation_id = None
# Master Agent 路由决策消耗的 token通过 sub_usage 事件发送给上层
routing_tokens = task_analysis.get("routing_tokens", 0)
if routing_tokens > 0:
yield self._format_sse_event("sub_usage", {"total_tokens": routing_tokens})
async for event in self._execute_sub_agent_stream(
agent_data["config"],
message,
@@ -1054,6 +1069,7 @@ class MultiAgentOrchestrator:
except:
pass
# 直接透传所有事件(包括 sub_usage累加统一由上层处理
yield event
# 6. 如果有会话 ID发送一个包含它的事件
@@ -2612,6 +2628,17 @@ class MultiAgentOrchestrator:
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
# 提取整合消耗的 token
merge_tokens = 0
if hasattr(response, 'usage_metadata') and response.usage_metadata:
um = response.usage_metadata
merge_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
elif hasattr(response, 'response_metadata') and response.response_metadata:
token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {})
if isinstance(token_usage, dict):
merge_tokens = token_usage.get("total_tokens", 0)
self._last_merge_tokens = merge_tokens
# 提取响应内容
if hasattr(response, 'content'):
merged_response = response.content
@@ -2621,7 +2648,8 @@ class MultiAgentOrchestrator:
logger.info(
"Master Agent 整合完成",
extra={
"merged_length": len(merged_response)
"merged_length": len(merged_response),
"merge_tokens": merge_tokens
}
)