diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 26abd0f9..bd9106e5 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -427,7 +427,11 @@ class AppChatService: meta_data={ "mode": result.get("mode"), "elapsed_time": result.get("elapsed_time"), - "sub_results": result.get("sub_results") + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) } ) @@ -469,6 +473,7 @@ class AppChatService: yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" full_content = "" + total_tokens = 0 # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) @@ -485,16 +490,26 @@ class AppChatService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ): - yield event - # 尝试提取内容(用于保存) - if "data:" in event: - try: - data_line = event.split("data: ", 1)[1].strip() - data = json.loads(data_line) - if "content" in data: - full_content += data["content"] - except: - pass + if "sub_usage" in event: + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "total_tokens" in data: + total_tokens += data["total_tokens"] + except: + pass + else: + yield event + # 尝试提取内容(用于保存) + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "content" in data: + full_content += data["content"] + except: + pass elapsed_time = time.time() - start_time @@ -510,7 +525,12 @@ class AppChatService: role="assistant", content=full_content, meta_data={ - "elapsed_time": elapsed_time + "elapsed_time": elapsed_time, + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } } ) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index dc01e541..9a3e1d37 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -678,6 +678,11 @@ class DraftRunService: elapsed_time = time.time() - start_time + if sub_agent: + yield self._format_sse_event("sub_usage", { + "total_tokens": total_tokens + }) + # 10. 保存会话消息 if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): await self._save_conversation_message( diff --git a/api/app/services/handoffs_service.py b/api/app/services/handoffs_service.py index 114e9945..10e4d646 100644 --- a/api/app/services/handoffs_service.py +++ b/api/app/services/handoffs_service.py @@ -4,7 +4,7 @@ import uuid from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated from typing_extensions import TypedDict -from langchain_core.messages import HumanMessage, AIMessage, BaseMessage +from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, AIMessageChunk from langgraph.graph import StateGraph, START, END from langgraph.types import Command from langgraph.checkpoint.memory import MemorySaver @@ -727,9 +727,12 @@ class HandoffsService: # 提取响应 response_content = "" + total_tokens = 0 for msg in result.get("messages", []): if isinstance(msg, AIMessage): response_content = msg.content + response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None + total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 break return { @@ -737,7 +740,12 @@ class HandoffsService: "active_agent": result.get("active_agent"), "response": response_content, "message_count": len(result.get("messages", [])), - "handoff_count": result.get("handoff_count", 0) + "handoff_count": result.get("handoff_count", 0), + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } } async def chat_stream( @@ -830,6 +838,12 @@ class HandoffsService: # 捕获 LLM 结束事件,输出收集到的工具调用 elif kind == "on_chat_model_end": + output_message = event.get("data", {}).get("output", {}) + if isinstance(output_message, AIMessageChunk): + response_meta = output_message.response_metadata if hasattr(output_message, 'response_metadata') else None + total_tokens = response_meta.get("token_usage", {}).get("total_tokens", + 0) if response_meta else 0 + yield f"event: sub_usage\ndata: {json.dumps({"total_tokens": total_tokens}, ensure_ascii=False)}\n\n" if collected_tool_calls: # 找到参数最完整的 transfer 工具调用 best_tc = None diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index d9062eaf..b28bafbf 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -280,14 +280,22 @@ class MultiAgentOrchestrator: # 4. 提取子 Agent 的 conversation_id(用于多轮对话) sub_conversation_id = None + total_tokens = 0 + if isinstance(results, dict): sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id") + # 提取 token 信息 + usage = results.get("usage", {}) or results.get("result", {}).get("usage", {}) + total_tokens += usage.get("total_tokens", 0) elif isinstance(results, list) and results: for item in results: if "result" in item: sub_conversation_id = item["result"].get("conversation_id") if sub_conversation_id: break + # 累加每个子 Agent 的 token + usage = item.get("usage", {}) or item.get("result", {}).get("usage", {}) + total_tokens += usage.get("total_tokens", 0) logger.info( "多 Agent 任务完成", @@ -301,9 +309,15 @@ class MultiAgentOrchestrator: return { "message": final_result, "conversation_id": sub_conversation_id, + "mode": OrchestrationMode.SUPERVISOR, "elapsed_time": elapsed_time, "strategy": routing_decision.get("collaboration_strategy", "single"), - "sub_results": results + "sub_results": results, + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } } except Exception as e: @@ -1552,10 +1566,12 @@ class MultiAgentOrchestrator: return { "message": result.get("response", ""), "conversation_id": result.get("conversation_id"), + "mode": OrchestrationMode.COLLABORATION, "elapsed_time": elapsed_time, "strategy": "collaboration", "active_agent": result.get("active_agent"), - "sub_results": result + "sub_results": result, + "usage": result.get("usage") } except Exception as e: diff --git a/api/app/services/multi_agent_service.py b/api/app/services/multi_agent_service.py index da984d16..c52814ed 100644 --- a/api/app/services/multi_agent_service.py +++ b/api/app/services/multi_agent_service.py @@ -1,5 +1,6 @@ """多 Agent 配置管理服务""" import uuid +import json from typing import Optional, List, Tuple, Any, Annotated from fastapi import Depends @@ -427,6 +428,23 @@ class MultiAgentService: memory=getattr(request, 'memory', True) # 记忆功能参数 ) + await self._save_conversation_message( + conversation_id=request.conversation_id, + user_message=request.message, + assistant_message=result.get("message", ""), + app_id=app_id, + user_id=request.user_id, + meta_data={ + "mode": result.get("mode"), + "elapsed_time": result.get("elapsed_time"), + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) + } + ) + return result async def run_stream( @@ -451,11 +469,14 @@ class MultiAgentService: raise ResourceNotFoundException("多 Agent 配置", str(app_id)) if not config.is_active: - raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED) + raise BusinessException("多 Agent 配置已禁用", BizCode.NOT_FOUND) # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) + full_content = "" + total_tokens = 0 + # 3. 流式执行任务 async for event in orchestrator.execute_stream( message=request.message, @@ -468,7 +489,88 @@ class MultiAgentService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ): - yield event + if "sub_usage" in event: + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "total_tokens" in data: + total_tokens += data["total_tokens"] + except: + pass + else: + yield event + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "content" in data: + full_content += data["content"] + except: + pass + + await self._save_conversation_message( + conversation_id=request.conversation_id, + user_message=request.message, + assistant_message=full_content, + app_id=app_id, + user_id=request.user_id, + meta_data={ + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } + } + ) + + async def _save_conversation_message( + self, + conversation_id: uuid.UUID, + user_message: str, + assistant_message: str, + meta_data: dict, + app_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None + ) -> None: + """保存会话消息 + + Args: + conversation_id: 会话ID + user_message: 用户消息 + assistant_message: AI 回复消息 + meta_data: 元数据(包括 token 消耗) + app_id: 应用ID + user_id: 用户ID + """ + try: + from app.services.conversation_service import ConversationService + + conversation_service = ConversationService(self.db) + + conversation_service.add_message( + conversation_id=conversation_id, + role="user", + content=user_message + ) + conversation_service.add_message( + conversation_id=conversation_id, + role="assistant", + content=assistant_message, + meta_data=meta_data + ) + + logger.debug( + "保存多 Agent 会话消息", + extra={ + "conversation_id": conversation_id, + "user_message_length": len(user_message), + "assistant_message_length": len(assistant_message) + } + ) + + except Exception as e: + logger.warning("保存会话消息失败", extra={"error": str(e)}) # def add_sub_agent( # self,