fix(app):

1. Token consumption of the omni model;
2. Token consumption of the cluster includes sub-agents
This commit is contained in:
Timebomb2018
2026-03-30 18:37:09 +08:00
parent ed90405439
commit 876c39b1b0
6 changed files with 92 additions and 21 deletions

View File

@@ -254,6 +254,33 @@ class LangChainAgent:
return messages
@staticmethod
def _extract_tokens_from_message(msg) -> int:
"""从 AIMessage 或类似对象中提取 total_tokens兼容多种 provider 格式
支持的格式:
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
- response_metadata.usage.total_tokens (部分 provider)
- usage_metadata.total_tokens (LangChain 新版)
"""
total = 0
# 1. response_metadata
response_meta = getattr(msg, "response_metadata", None)
if response_meta and isinstance(response_meta, dict):
# 尝试 token_usage 路径
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
if isinstance(token_usage, dict):
total = token_usage.get("total_tokens", 0)
# 2. usage_metadataLangChain 新版 AIMessage 属性)
if not total:
usage_meta = getattr(msg, "usage_metadata", None)
if usage_meta:
if isinstance(usage_meta, dict):
total = usage_meta.get("total_tokens", 0)
else:
total = getattr(usage_meta, "total_tokens", 0)
return total or 0
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
构建多模态消息内容
@@ -412,8 +439,7 @@ class LangChainAgent:
else:
content = str(msg.content)
logger.debug(f"转换为字符串: {content[:100]}...")
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
total_tokens = self._extract_tokens_from_message(msg)
break
logger.info(f"最终提取的内容长度: {len(content)}")
@@ -458,7 +484,7 @@ class LangChainAgent:
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[str | int, None]:
"""执行流式对话
Args:
@@ -594,15 +620,13 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
# 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get(
"total_tokens",
0
) if response_meta else 0
yield total_tokens
stream_total_tokens = self._extract_tokens_from_message(msg)
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
yield stream_total_tokens
break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,

View File

@@ -58,7 +58,7 @@ class RedBearModelFactory:
write=60.0,
pool=10.0,
)
return {
params = {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
@@ -66,6 +66,10 @@ class RedBearModelFactory:
"max_retries": config.max_retries,
**config.extra_params
}
# 流式模式下启用 stream_usage 以获取 token 统计
if config.extra_params.get("streaming"):
params["stream_usage"] = True
return params
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
# 使用 httpx.Timeout 对象来设置详细的超时配置
@@ -78,7 +82,7 @@ class RedBearModelFactory:
write=60.0, # 写入超时60秒
pool=10.0, # 连接池超时10秒
)
return {
params = {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
@@ -86,6 +90,10 @@ class RedBearModelFactory:
"max_retries": config.max_retries,
**config.extra_params
}
# 流式模式下启用 stream_usage 以获取 token 统计
if config.extra_params.get("streaming"):
params["stream_usage"] = True
return params
elif provider == ModelProvider.DASHSCOPE:
# DashScope (通义千问) 使用自己的参数格式
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数

View File

@@ -99,7 +99,7 @@ class SimpleMCPClient:
# 建立 SSE 连接
response = await self._session.get(self.server_url)
if response.status != 200:
if not (200 <= response.status < 300):
error_text = await response.text()
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
@@ -190,7 +190,7 @@ class SimpleMCPClient:
try:
async with self._session.post(self._endpoint_url, json=request) as response:
if response.status != 200:
if not (200 <= response.status < 300):
error_text = await response.text()
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
@@ -205,7 +205,7 @@ class SimpleMCPClient:
raise MCPConnectionError("endpoint URL 未初始化")
async with self._session.post(self._endpoint_url, json=notification) as response:
if response.status != 200:
if not (200 <= response.status < 300):
logger.warning(f"通知发送失败: {response.status}")
async def _initialize_modelscope_session(self):
@@ -223,7 +223,7 @@ class SimpleMCPClient:
try:
async with self._session.post(self.server_url, json=init_request) as response:
if response.status != 200:
if not (200 <= response.status < 300):
error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")