fix(app):
1. Correctly count token consumption for the omni model; 2. Include sub-agent token consumption in the cluster's total
This commit is contained in:
@@ -254,6 +254,33 @@ class LangChainAgent:
|
|||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
@staticmethod
def _extract_tokens_from_message(msg) -> int:
    """Extract ``total_tokens`` from an AIMessage-like object.

    Several provider layouts are probed in order and the first positive
    count wins:

    - ``response_metadata["token_usage"]["total_tokens"]`` (OpenAI / ChatOpenAI)
    - ``response_metadata["usage"]["total_tokens"]`` (some other providers)
    - ``usage_metadata.total_tokens`` (newer LangChain ``AIMessage``
      attribute, exposed either as a dict or as an object)

    Args:
        msg: Message object to inspect. Any of the attributes above may be
            missing, ``None`` or malformed.

    Returns:
        The total token count as an ``int``, or 0 when no usable value is
        found.
    """

    def _read_total(usage) -> int:
        """Return a positive ``total_tokens`` from a dict or object, else 0."""
        if isinstance(usage, dict):
            value = usage.get("total_tokens")
        else:
            value = getattr(usage, "total_tokens", None)
        # Reject None, strings and other non-numeric values so callers can
        # always add the result to a running total.
        if isinstance(value, (int, float)) and value > 0:
            return int(value)
        return 0

    # 1. response_metadata: try each known key on its own, so that a
    #    ``token_usage`` entry lacking ``total_tokens`` does not hide a
    #    populated ``usage`` entry.
    response_meta = getattr(msg, "response_metadata", None)
    if isinstance(response_meta, dict):
        for key in ("token_usage", "usage"):
            total = _read_total(response_meta.get(key))
            if total:
                return total

    # 2. usage_metadata (attribute on newer LangChain AIMessage objects).
    return _read_total(getattr(msg, "usage_metadata", None))
|
||||||
|
|
||||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
构建多模态消息内容
|
构建多模态消息内容
|
||||||
@@ -412,8 +439,7 @@ class LangChainAgent:
|
|||||||
else:
|
else:
|
||||||
content = str(msg.content)
|
content = str(msg.content)
|
||||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
total_tokens = self._extract_tokens_from_message(msg)
|
||||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||||
@@ -458,7 +484,7 @@ class LangChainAgent:
|
|||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
memory_flag: Optional[bool] = True,
|
memory_flag: Optional[bool] = True,
|
||||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str | int, None]:
|
||||||
"""执行流式对话
|
"""执行流式对话
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -594,15 +620,13 @@ class LangChainAgent:
|
|||||||
|
|
||||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||||
# 统计token消耗
|
# 统计token消耗
|
||||||
|
# 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages
|
||||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
stream_total_tokens = self._extract_tokens_from_message(msg)
|
||||||
total_tokens = response_meta.get("token_usage", {}).get(
|
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
||||||
"total_tokens",
|
yield stream_total_tokens
|
||||||
0
|
|
||||||
) if response_meta else 0
|
|
||||||
yield total_tokens
|
|
||||||
break
|
break
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class RedBearModelFactory:
|
|||||||
write=60.0,
|
write=60.0,
|
||||||
pool=10.0,
|
pool=10.0,
|
||||||
)
|
)
|
||||||
return {
|
params = {
|
||||||
"model": config.model_name,
|
"model": config.model_name,
|
||||||
"base_url": config.base_url,
|
"base_url": config.base_url,
|
||||||
"api_key": config.api_key,
|
"api_key": config.api_key,
|
||||||
@@ -66,6 +66,10 @@ class RedBearModelFactory:
|
|||||||
"max_retries": config.max_retries,
|
"max_retries": config.max_retries,
|
||||||
**config.extra_params
|
**config.extra_params
|
||||||
}
|
}
|
||||||
|
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||||
|
if config.extra_params.get("streaming"):
|
||||||
|
params["stream_usage"] = True
|
||||||
|
return params
|
||||||
|
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||||
@@ -78,7 +82,7 @@ class RedBearModelFactory:
|
|||||||
write=60.0, # 写入超时:60秒
|
write=60.0, # 写入超时:60秒
|
||||||
pool=10.0, # 连接池超时:10秒
|
pool=10.0, # 连接池超时:10秒
|
||||||
)
|
)
|
||||||
return {
|
params = {
|
||||||
"model": config.model_name,
|
"model": config.model_name,
|
||||||
"base_url": config.base_url,
|
"base_url": config.base_url,
|
||||||
"api_key": config.api_key,
|
"api_key": config.api_key,
|
||||||
@@ -86,6 +90,10 @@ class RedBearModelFactory:
|
|||||||
"max_retries": config.max_retries,
|
"max_retries": config.max_retries,
|
||||||
**config.extra_params
|
**config.extra_params
|
||||||
}
|
}
|
||||||
|
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||||
|
if config.extra_params.get("streaming"):
|
||||||
|
params["stream_usage"] = True
|
||||||
|
return params
|
||||||
elif provider == ModelProvider.DASHSCOPE:
|
elif provider == ModelProvider.DASHSCOPE:
|
||||||
# DashScope (通义千问) 使用自己的参数格式
|
# DashScope (通义千问) 使用自己的参数格式
|
||||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ class SimpleMCPClient:
|
|||||||
# 建立 SSE 连接
|
# 建立 SSE 连接
|
||||||
response = await self._session.get(self.server_url)
|
response = await self._session.get(self.server_url)
|
||||||
|
|
||||||
if response.status != 200:
|
if not (200 <= response.status < 300):
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||||
|
|
||||||
@@ -190,7 +190,7 @@ class SimpleMCPClient:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||||
if response.status != 200:
|
if not (200 <= response.status < 300):
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||||
|
|
||||||
@@ -205,7 +205,7 @@ class SimpleMCPClient:
|
|||||||
raise MCPConnectionError("endpoint URL 未初始化")
|
raise MCPConnectionError("endpoint URL 未初始化")
|
||||||
|
|
||||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||||
if response.status != 200:
|
if not (200 <= response.status < 300):
|
||||||
logger.warning(f"通知发送失败: {response.status}")
|
logger.warning(f"通知发送失败: {response.status}")
|
||||||
|
|
||||||
async def _initialize_modelscope_session(self):
|
async def _initialize_modelscope_session(self):
|
||||||
@@ -223,7 +223,7 @@ class SimpleMCPClient:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._session.post(self.server_url, json=init_request) as response:
|
async with self._session.post(self.server_url, json=init_request) as response:
|
||||||
if response.status != 200:
|
if not (200 <= response.status < 300):
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||||
|
|
||||||
|
|||||||
@@ -631,13 +631,13 @@ class AppChatService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
if "sub_usage" in event:
|
# 拦截 sub_usage 事件,累加 token
|
||||||
|
if "event: sub_usage" in event:
|
||||||
if "data:" in event:
|
if "data:" in event:
|
||||||
try:
|
try:
|
||||||
data_line = event.split("data: ", 1)[1].strip()
|
data_line = event.split("data: ", 1)[1].strip()
|
||||||
data = json.loads(data_line)
|
data = json.loads(data_line)
|
||||||
if "total_tokens" in data:
|
total_tokens += data.get("total_tokens", 0)
|
||||||
total_tokens += data["total_tokens"]
|
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -403,6 +403,17 @@ class MasterAgentRouter:
|
|||||||
response = await llm.ainvoke(prompt)
|
response = await llm.ainvoke(prompt)
|
||||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||||
|
|
||||||
|
# 提取 token 消耗
|
||||||
|
self._last_routing_tokens = 0
|
||||||
|
if hasattr(response, 'usage_metadata') and response.usage_metadata:
|
||||||
|
um = response.usage_metadata
|
||||||
|
self._last_routing_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
|
||||||
|
elif hasattr(response, 'response_metadata') and response.response_metadata:
|
||||||
|
token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {})
|
||||||
|
if isinstance(token_usage, dict):
|
||||||
|
self._last_routing_tokens = token_usage.get("total_tokens", 0)
|
||||||
|
logger.info(f"Master Agent 路由 token 消耗: {self._last_routing_tokens}")
|
||||||
|
|
||||||
# 提取响应内容
|
# 提取响应内容
|
||||||
if hasattr(response, 'content'):
|
if hasattr(response, 'content'):
|
||||||
return response.content
|
return response.content
|
||||||
|
|||||||
@@ -287,6 +287,11 @@ class MultiAgentOrchestrator:
|
|||||||
sub_conversation_id = None
|
sub_conversation_id = None
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
|
||||||
|
# 累加 Master Agent 路由决策消耗的 token
|
||||||
|
total_tokens += task_analysis.get("routing_tokens", 0)
|
||||||
|
# 累加 Master Agent 整合消耗的 token
|
||||||
|
total_tokens += getattr(self, '_last_merge_tokens', 0)
|
||||||
|
|
||||||
if isinstance(results, dict):
|
if isinstance(results, dict):
|
||||||
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
|
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
|
||||||
# 提取 token 信息
|
# 提取 token 信息
|
||||||
@@ -358,12 +363,16 @@ class MultiAgentOrchestrator:
|
|||||||
variables=variables
|
variables=variables
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 获取路由决策消耗的 token
|
||||||
|
routing_tokens = getattr(self.router, '_last_routing_tokens', 0)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Master Agent 分析完成",
|
"Master Agent 分析完成",
|
||||||
extra={
|
extra={
|
||||||
"selected_agent": routing_decision.get("selected_agent_id"),
|
"selected_agent": routing_decision.get("selected_agent_id"),
|
||||||
"confidence": routing_decision.get("confidence"),
|
"confidence": routing_decision.get("confidence"),
|
||||||
"strategy": routing_decision.get("strategy")
|
"strategy": routing_decision.get("strategy"),
|
||||||
|
"routing_tokens": routing_tokens
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -372,7 +381,8 @@ class MultiAgentOrchestrator:
|
|||||||
"variables": variables or {},
|
"variables": variables or {},
|
||||||
"sub_agents": self.config.sub_agents,
|
"sub_agents": self.config.sub_agents,
|
||||||
"initial_context": variables or {},
|
"initial_context": variables or {},
|
||||||
"routing_decision": routing_decision
|
"routing_decision": routing_decision,
|
||||||
|
"routing_tokens": routing_tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _execute_sequential(
|
async def _execute_sequential(
|
||||||
@@ -1032,6 +1042,11 @@ class MultiAgentOrchestrator:
|
|||||||
|
|
||||||
# 5. 流式执行子 Agent
|
# 5. 流式执行子 Agent
|
||||||
sub_conversation_id = None
|
sub_conversation_id = None
|
||||||
|
# Master Agent 路由决策消耗的 token,通过 sub_usage 事件发送给上层
|
||||||
|
routing_tokens = task_analysis.get("routing_tokens", 0)
|
||||||
|
if routing_tokens > 0:
|
||||||
|
yield self._format_sse_event("sub_usage", {"total_tokens": routing_tokens})
|
||||||
|
|
||||||
async for event in self._execute_sub_agent_stream(
|
async for event in self._execute_sub_agent_stream(
|
||||||
agent_data["config"],
|
agent_data["config"],
|
||||||
message,
|
message,
|
||||||
@@ -1054,6 +1069,7 @@ class MultiAgentOrchestrator:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# 直接透传所有事件(包括 sub_usage),累加统一由上层处理
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
# 6. 如果有会话 ID,发送一个包含它的事件
|
# 6. 如果有会话 ID,发送一个包含它的事件
|
||||||
@@ -2612,6 +2628,17 @@ class MultiAgentOrchestrator:
|
|||||||
|
|
||||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||||
|
|
||||||
|
# 提取整合消耗的 token
|
||||||
|
merge_tokens = 0
|
||||||
|
if hasattr(response, 'usage_metadata') and response.usage_metadata:
|
||||||
|
um = response.usage_metadata
|
||||||
|
merge_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0)
|
||||||
|
elif hasattr(response, 'response_metadata') and response.response_metadata:
|
||||||
|
token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {})
|
||||||
|
if isinstance(token_usage, dict):
|
||||||
|
merge_tokens = token_usage.get("total_tokens", 0)
|
||||||
|
self._last_merge_tokens = merge_tokens
|
||||||
|
|
||||||
# 提取响应内容
|
# 提取响应内容
|
||||||
if hasattr(response, 'content'):
|
if hasattr(response, 'content'):
|
||||||
merged_response = response.content
|
merged_response = response.content
|
||||||
@@ -2621,7 +2648,8 @@ class MultiAgentOrchestrator:
|
|||||||
logger.info(
|
logger.info(
|
||||||
"Master Agent 整合完成",
|
"Master Agent 整合完成",
|
||||||
extra={
|
extra={
|
||||||
"merged_length": len(merged_response)
|
"merged_length": len(merged_response),
|
||||||
|
"merge_tokens": merge_tokens
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user