From 876c39b1b0a8e4bed9a0f8339222839fb6027118 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Mon, 30 Mar 2026 18:37:09 +0800 Subject: [PATCH] fix(app): 1. Token consumption of the omni model; 2. Token consumption of the cluster includes sub-agents --- api/app/core/agent/langchain_agent.py | 42 +++++++++++++++----- api/app/core/models/base.py | 12 +++++- api/app/core/tools/mcp/client.py | 8 ++-- api/app/services/app_chat_service.py | 6 +-- api/app/services/master_agent_router.py | 11 +++++ api/app/services/multi_agent_orchestrator.py | 34 ++++++++++++++-- 6 files changed, 92 insertions(+), 21 deletions(-) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 464a668a..9776cc29 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -254,6 +254,33 @@ class LangChainAgent: return messages + @staticmethod + def _extract_tokens_from_message(msg) -> int: + """从 AIMessage 或类似对象中提取 total_tokens,兼容多种 provider 格式 + + 支持的格式: + - response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI) + - response_metadata.usage.total_tokens (部分 provider) + - usage_metadata.total_tokens (LangChain 新版) + """ + total = 0 + # 1. response_metadata + response_meta = getattr(msg, "response_metadata", None) + if response_meta and isinstance(response_meta, dict): + # 尝试 token_usage 路径 + token_usage = response_meta.get("token_usage") or response_meta.get("usage", {}) + if isinstance(token_usage, dict): + total = token_usage.get("total_tokens", 0) + # 2. usage_metadata(LangChain 新版 AIMessage 属性) + if not total: + usage_meta = getattr(msg, "usage_metadata", None) + if usage_meta: + if isinstance(usage_meta, dict): + total = usage_meta.get("total_tokens", 0) + else: + total = getattr(usage_meta, "total_tokens", 0) + return total or 0 + def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ 构建多模态消息内容 @@ -412,8 +439,7 @@ class LangChainAgent: else: content = str(msg.content) logger.debug(f"转换为字符串: {content[:100]}...") - response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None - total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 + total_tokens = self._extract_tokens_from_message(msg) break logger.info(f"最终提取的内容长度: {len(content)}") @@ -458,7 +484,7 @@ class LangChainAgent: user_rag_memory_id: Optional[str] = None, memory_flag: Optional[bool] = True, files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[str | int, None]: """执行流式对话 Args: @@ -594,15 +620,13 @@ class LangChainAgent: logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") # 统计token消耗 + # 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages output_messages = event.get("data", {}).get("output", {}).get("messages", []) for msg in reversed(output_messages): if isinstance(msg, AIMessage): - response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None - total_tokens = response_meta.get("token_usage", {}).get( - "total_tokens", - 0 - ) if response_meta else 0 - yield total_tokens + stream_total_tokens = self._extract_tokens_from_message(msg) + logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}") + yield stream_total_tokens break if memory_flag: await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index 80117f27..a4dbc092 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -58,7 +58,7 @@ class RedBearModelFactory: write=60.0, pool=10.0, ) - return { + params = { "model": config.model_name, "base_url": config.base_url, "api_key": config.api_key, @@ -66,6 +66,10 @@ class RedBearModelFactory: "max_retries": config.max_retries, **config.extra_params } + # 流式模式下启用 stream_usage 以获取 token 统计 + if config.extra_params.get("streaming"): + params["stream_usage"] = True + return params if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]: # 使用 httpx.Timeout 对象来设置详细的超时配置 @@ -78,7 +82,7 @@ class RedBearModelFactory: write=60.0, # 写入超时:60秒 pool=10.0, # 连接池超时:10秒 ) - return { + params = { "model": config.model_name, "base_url": config.base_url, "api_key": config.api_key, @@ -86,6 +90,10 @@ class RedBearModelFactory: "max_retries": config.max_retries, **config.extra_params } + # 流式模式下启用 stream_usage 以获取 token 统计 + if config.extra_params.get("streaming"): + params["stream_usage"] = True + return params elif provider == ModelProvider.DASHSCOPE: # DashScope (通义千问) 使用自己的参数格式 # 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数 diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 6df6df51..3539d33a 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -99,7 +99,7 @@ class SimpleMCPClient: # 建立 SSE 连接 response = await self._session.get(self.server_url) - if response.status != 200: + if not (200 <= response.status < 300): error_text = await response.text() raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}") @@ -190,7 +190,7 @@ class SimpleMCPClient: try: async with self._session.post(self._endpoint_url, json=request) as response: - if response.status != 200: + if not (200 <= response.status < 300): error_text = await response.text() raise MCPConnectionError(f"请求失败 {response.status}: {error_text}") @@ -205,7 +205,7 @@ class SimpleMCPClient: raise MCPConnectionError("endpoint URL 未初始化") async with self._session.post(self._endpoint_url, json=notification) as response: - if response.status != 200: + if not (200 <= response.status < 300): logger.warning(f"通知发送失败: {response.status}") async def _initialize_modelscope_session(self): @@ -223,7 +223,7 @@ class SimpleMCPClient: try: async with self._session.post(self.server_url, json=init_request) as response: - if response.status != 200: + if not (200 <= response.status < 300): error_text = await response.text() raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}") diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 90474428..b5f9f194 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -631,13 +631,13 @@ class AppChatService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ): - if "sub_usage" in event: + # 拦截 sub_usage 事件,累加 token + if "event: sub_usage" in event: if "data:" in event: try: data_line = event.split("data: ", 1)[1].strip() data = json.loads(data_line) - if "total_tokens" in data: - total_tokens += data["total_tokens"] + total_tokens += data.get("total_tokens", 0) except: pass else: diff --git a/api/app/services/master_agent_router.py b/api/app/services/master_agent_router.py index b0f43b51..954d3b2b 100644 --- a/api/app/services/master_agent_router.py +++ b/api/app/services/master_agent_router.py @@ -403,6 +403,17 @@ class MasterAgentRouter: response = await llm.ainvoke(prompt) ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) + # 提取 token 消耗 + self._last_routing_tokens = 0 + if hasattr(response, 'usage_metadata') and response.usage_metadata: + um = response.usage_metadata + self._last_routing_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0) + elif hasattr(response, 'response_metadata') and response.response_metadata: + token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {}) + if isinstance(token_usage, dict): + self._last_routing_tokens = token_usage.get("total_tokens", 0) + logger.info(f"Master Agent 路由 token 消耗: {self._last_routing_tokens}") + # 提取响应内容 if hasattr(response, 'content'): return response.content diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index 60a3b5b8..1330caad 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -287,6 +287,11 @@ class MultiAgentOrchestrator: sub_conversation_id = None total_tokens = 0 + # 累加 Master Agent 路由决策消耗的 token + total_tokens += task_analysis.get("routing_tokens", 0) + # 累加 Master Agent 整合消耗的 token + total_tokens += getattr(self, '_last_merge_tokens', 0) + if isinstance(results, dict): sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id") # 提取 token 信息 @@ -358,12 +363,16 @@ class MultiAgentOrchestrator: variables=variables ) + # 获取路由决策消耗的 token + routing_tokens = getattr(self.router, '_last_routing_tokens', 0) + logger.info( "Master Agent 分析完成", extra={ "selected_agent": routing_decision.get("selected_agent_id"), "confidence": routing_decision.get("confidence"), - "strategy": routing_decision.get("strategy") + "strategy": routing_decision.get("strategy"), + "routing_tokens": routing_tokens } ) @@ -372,7 +381,8 @@ class MultiAgentOrchestrator: "variables": variables or {}, "sub_agents": self.config.sub_agents, "initial_context": variables or {}, - "routing_decision": routing_decision + "routing_decision": routing_decision, + "routing_tokens": routing_tokens } async def _execute_sequential( @@ -1032,6 +1042,11 @@ class MultiAgentOrchestrator: # 5. 流式执行子 Agent sub_conversation_id = None + # Master Agent 路由决策消耗的 token,通过 sub_usage 事件发送给上层 + routing_tokens = task_analysis.get("routing_tokens", 0) + if routing_tokens > 0: + yield self._format_sse_event("sub_usage", {"total_tokens": routing_tokens}) + async for event in self._execute_sub_agent_stream( agent_data["config"], message, @@ -1054,6 +1069,7 @@ class MultiAgentOrchestrator: except: pass + # 直接透传所有事件(包括 sub_usage),累加统一由上层处理 yield event # 6. 如果有会话 ID,发送一个包含它的事件 @@ -2612,6 +2628,17 @@ class MultiAgentOrchestrator: ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) + # 提取整合消耗的 token + merge_tokens = 0 + if hasattr(response, 'usage_metadata') and response.usage_metadata: + um = response.usage_metadata + merge_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0) + elif hasattr(response, 'response_metadata') and response.response_metadata: + token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {}) + if isinstance(token_usage, dict): + merge_tokens = token_usage.get("total_tokens", 0) + self._last_merge_tokens = merge_tokens + # 提取响应内容 if hasattr(response, 'content'): merged_response = response.content @@ -2621,7 +2648,8 @@ class MultiAgentOrchestrator: logger.info( "Master Agent 整合完成", extra={ - "merged_length": len(merged_response) + "merged_length": len(merged_response), + "merge_tokens": merge_tokens } )