fix(app):
1. Token consumption of the omni model; 2. Token consumption of the cluster includes sub-agents
This commit is contained in:
@@ -254,6 +254,33 @@ class LangChainAgent:
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _extract_tokens_from_message(msg) -> int:
|
||||
"""从 AIMessage 或类似对象中提取 total_tokens,兼容多种 provider 格式
|
||||
|
||||
支持的格式:
|
||||
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
|
||||
- response_metadata.usage.total_tokens (部分 provider)
|
||||
- usage_metadata.total_tokens (LangChain 新版)
|
||||
"""
|
||||
total = 0
|
||||
# 1. response_metadata
|
||||
response_meta = getattr(msg, "response_metadata", None)
|
||||
if response_meta and isinstance(response_meta, dict):
|
||||
# 尝试 token_usage 路径
|
||||
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
|
||||
if isinstance(token_usage, dict):
|
||||
total = token_usage.get("total_tokens", 0)
|
||||
# 2. usage_metadata(LangChain 新版 AIMessage 属性)
|
||||
if not total:
|
||||
usage_meta = getattr(msg, "usage_metadata", None)
|
||||
if usage_meta:
|
||||
if isinstance(usage_meta, dict):
|
||||
total = usage_meta.get("total_tokens", 0)
|
||||
else:
|
||||
total = getattr(usage_meta, "total_tokens", 0)
|
||||
return total or 0
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建多模态消息内容
|
||||
@@ -412,8 +439,7 @@ class LangChainAgent:
|
||||
else:
|
||||
content = str(msg.content)
|
||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
total_tokens = self._extract_tokens_from_message(msg)
|
||||
break
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
@@ -458,7 +484,7 @@ class LangChainAgent:
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
) -> AsyncGenerator[str | int, None]:
|
||||
"""执行流式对话
|
||||
|
||||
Args:
|
||||
@@ -594,15 +620,13 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
# 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get(
|
||||
"total_tokens",
|
||||
0
|
||||
) if response_meta else 0
|
||||
yield total_tokens
|
||||
stream_total_tokens = self._extract_tokens_from_message(msg)
|
||||
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
||||
yield stream_total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
|
||||
@@ -58,7 +58,7 @@ class RedBearModelFactory:
|
||||
write=60.0,
|
||||
pool=10.0,
|
||||
)
|
||||
return {
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
@@ -66,6 +66,10 @@ class RedBearModelFactory:
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||
if config.extra_params.get("streaming"):
|
||||
params["stream_usage"] = True
|
||||
return params
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||
@@ -78,7 +82,7 @@ class RedBearModelFactory:
|
||||
write=60.0, # 写入超时:60秒
|
||||
pool=10.0, # 连接池超时:10秒
|
||||
)
|
||||
return {
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
@@ -86,6 +90,10 @@ class RedBearModelFactory:
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||
if config.extra_params.get("streaming"):
|
||||
params["stream_usage"] = True
|
||||
return params
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
# DashScope (通义千问) 使用自己的参数格式
|
||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||
|
||||
@@ -99,7 +99,7 @@ class SimpleMCPClient:
|
||||
# 建立 SSE 连接
|
||||
response = await self._session.get(self.server_url)
|
||||
|
||||
if response.status != 200:
|
||||
if not (200 <= response.status < 300):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||
|
||||
@@ -190,7 +190,7 @@ class SimpleMCPClient:
|
||||
|
||||
try:
|
||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||
if response.status != 200:
|
||||
if not (200 <= response.status < 300):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||
|
||||
@@ -205,7 +205,7 @@ class SimpleMCPClient:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||
if response.status != 200:
|
||||
if not (200 <= response.status < 300):
|
||||
logger.warning(f"通知发送失败: {response.status}")
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
@@ -223,7 +223,7 @@ class SimpleMCPClient:
|
||||
|
||||
try:
|
||||
async with self._session.post(self.server_url, json=init_request) as response:
|
||||
if response.status != 200:
|
||||
if not (200 <= response.status < 300):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||
|
||||
|
||||
@@ -631,13 +631,13 @@ class AppChatService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
if "sub_usage" in event:
|
||||
# 拦截 sub_usage 事件,累加 token
|
||||
if "event: sub_usage" in event:
|
||||
if "data:" in event:
|
||||
try:
|
||||
data_line = event.split("data: ", 1)[1].strip()
|
||||
data = json.loads(data_line)
|
||||
if "total_tokens" in data:
|
||||
total_tokens += data["total_tokens"]
|
||||
total_tokens += data.get("total_tokens", 0)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
|
||||
@@ -403,6 +403,17 @@ class MasterAgentRouter:
|
||||
response = await llm.ainvoke(prompt)
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取 token 消耗
|
||||
self._last_routing_tokens = 0
|
||||
if hasattr(response, 'usage_metadata') and response.usage_metadata:
|
||||
um = response.usage_metadata
|
||||
self._last_routing_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
|
||||
elif hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {})
|
||||
if isinstance(token_usage, dict):
|
||||
self._last_routing_tokens = token_usage.get("total_tokens", 0)
|
||||
logger.info(f"Master Agent 路由 token 消耗: {self._last_routing_tokens}")
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
return response.content
|
||||
|
||||
@@ -287,6 +287,11 @@ class MultiAgentOrchestrator:
|
||||
sub_conversation_id = None
|
||||
total_tokens = 0
|
||||
|
||||
# 累加 Master Agent 路由决策消耗的 token
|
||||
total_tokens += task_analysis.get("routing_tokens", 0)
|
||||
# 累加 Master Agent 整合消耗的 token
|
||||
total_tokens += getattr(self, '_last_merge_tokens', 0)
|
||||
|
||||
if isinstance(results, dict):
|
||||
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
|
||||
# 提取 token 信息
|
||||
@@ -358,12 +363,16 @@ class MultiAgentOrchestrator:
|
||||
variables=variables
|
||||
)
|
||||
|
||||
# 获取路由决策消耗的 token
|
||||
routing_tokens = getattr(self.router, '_last_routing_tokens', 0)
|
||||
|
||||
logger.info(
|
||||
"Master Agent 分析完成",
|
||||
extra={
|
||||
"selected_agent": routing_decision.get("selected_agent_id"),
|
||||
"confidence": routing_decision.get("confidence"),
|
||||
"strategy": routing_decision.get("strategy")
|
||||
"strategy": routing_decision.get("strategy"),
|
||||
"routing_tokens": routing_tokens
|
||||
}
|
||||
)
|
||||
|
||||
@@ -372,7 +381,8 @@ class MultiAgentOrchestrator:
|
||||
"variables": variables or {},
|
||||
"sub_agents": self.config.sub_agents,
|
||||
"initial_context": variables or {},
|
||||
"routing_decision": routing_decision
|
||||
"routing_decision": routing_decision,
|
||||
"routing_tokens": routing_tokens
|
||||
}
|
||||
|
||||
async def _execute_sequential(
|
||||
@@ -1032,6 +1042,11 @@ class MultiAgentOrchestrator:
|
||||
|
||||
# 5. 流式执行子 Agent
|
||||
sub_conversation_id = None
|
||||
# Master Agent 路由决策消耗的 token,通过 sub_usage 事件发送给上层
|
||||
routing_tokens = task_analysis.get("routing_tokens", 0)
|
||||
if routing_tokens > 0:
|
||||
yield self._format_sse_event("sub_usage", {"total_tokens": routing_tokens})
|
||||
|
||||
async for event in self._execute_sub_agent_stream(
|
||||
agent_data["config"],
|
||||
message,
|
||||
@@ -1054,6 +1069,7 @@ class MultiAgentOrchestrator:
|
||||
except:
|
||||
pass
|
||||
|
||||
# 直接透传所有事件(包括 sub_usage),累加统一由上层处理
|
||||
yield event
|
||||
|
||||
# 6. 如果有会话 ID,发送一个包含它的事件
|
||||
@@ -2612,6 +2628,17 @@ class MultiAgentOrchestrator:
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取整合消耗的 token
|
||||
merge_tokens = 0
|
||||
if hasattr(response, 'usage_metadata') and response.usage_metadata:
|
||||
um = response.usage_metadata
|
||||
merge_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
|
||||
elif hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {})
|
||||
if isinstance(token_usage, dict):
|
||||
merge_tokens = token_usage.get("total_tokens", 0)
|
||||
self._last_merge_tokens = merge_tokens
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
merged_response = response.content
|
||||
@@ -2621,7 +2648,8 @@ class MultiAgentOrchestrator:
|
||||
logger.info(
|
||||
"Master Agent 整合完成",
|
||||
extra={
|
||||
"merged_length": len(merged_response)
|
||||
"merged_length": len(merged_response),
|
||||
"merge_tokens": merge_tokens
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user