Merge pull request #277 from SuanmoSuanyangTechnology/fix/token

feat(app)
This commit is contained in:
Mark
2026-02-02 19:06:14 +08:00
committed by GitHub
5 changed files with 175 additions and 18 deletions

View File

@@ -427,7 +427,11 @@ class AppChatService:
meta_data={ meta_data={
"mode": result.get("mode"), "mode": result.get("mode"),
"elapsed_time": result.get("elapsed_time"), "elapsed_time": result.get("elapsed_time"),
"sub_results": result.get("sub_results") "usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
} }
) )
@@ -469,6 +473,7 @@ class AppChatService:
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
full_content = "" full_content = ""
total_tokens = 0
# 2. 创建编排器 # 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config) orchestrator = MultiAgentOrchestrator(self.db, config)
@@ -485,16 +490,26 @@ class AppChatService:
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
): ):
yield event if "sub_usage" in event:
# 尝试提取内容(用于保存) if "data:" in event:
if "data:" in event: try:
try: data_line = event.split("data: ", 1)[1].strip()
data_line = event.split("data: ", 1)[1].strip() data = json.loads(data_line)
data = json.loads(data_line) if "total_tokens" in data:
if "content" in data: total_tokens += data["total_tokens"]
full_content += data["content"] except:
except: pass
pass else:
yield event
# 尝试提取内容(用于保存)
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "content" in data:
full_content += data["content"]
except:
pass
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
@@ -510,7 +525,12 @@ class AppChatService:
role="assistant", role="assistant",
content=full_content, content=full_content,
meta_data={ meta_data={
"elapsed_time": elapsed_time "elapsed_time": elapsed_time,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
} }
) )

View File

@@ -678,6 +678,11 @@ class DraftRunService:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
if sub_agent:
yield self._format_sse_event("sub_usage", {
"total_tokens": total_tokens
})
# 10. 保存会话消息 # 10. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
await self._save_conversation_message( await self._save_conversation_message(

View File

@@ -4,7 +4,7 @@ import uuid
from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, AIMessageChunk
from langgraph.graph import StateGraph, START, END from langgraph.graph import StateGraph, START, END
from langgraph.types import Command from langgraph.types import Command
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
@@ -727,9 +727,12 @@ class HandoffsService:
# 提取响应 # 提取响应
response_content = "" response_content = ""
total_tokens = 0
for msg in result.get("messages", []): for msg in result.get("messages", []):
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
response_content = msg.content response_content = msg.content
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
break break
return { return {
@@ -737,7 +740,12 @@ class HandoffsService:
"active_agent": result.get("active_agent"), "active_agent": result.get("active_agent"),
"response": response_content, "response": response_content,
"message_count": len(result.get("messages", [])), "message_count": len(result.get("messages", [])),
"handoff_count": result.get("handoff_count", 0) "handoff_count": result.get("handoff_count", 0),
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
} }
async def chat_stream( async def chat_stream(
@@ -830,6 +838,12 @@ class HandoffsService:
# 捕获 LLM 结束事件,输出收集到的工具调用 # 捕获 LLM 结束事件,输出收集到的工具调用
elif kind == "on_chat_model_end": elif kind == "on_chat_model_end":
output_message = event.get("data", {}).get("output", {})
if isinstance(output_message, AIMessageChunk):
response_meta = output_message.response_metadata if hasattr(output_message, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
0) if response_meta else 0
yield f"event: sub_usage\ndata: {json.dumps({"total_tokens": total_tokens}, ensure_ascii=False)}\n\n"
if collected_tool_calls: if collected_tool_calls:
# 找到参数最完整的 transfer 工具调用 # 找到参数最完整的 transfer 工具调用
best_tc = None best_tc = None

View File

@@ -280,14 +280,22 @@ class MultiAgentOrchestrator:
# 4. 提取子 Agent 的 conversation_id用于多轮对话 # 4. 提取子 Agent 的 conversation_id用于多轮对话
sub_conversation_id = None sub_conversation_id = None
total_tokens = 0
if isinstance(results, dict): if isinstance(results, dict):
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id") sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
# 提取 token 信息
usage = results.get("usage", {}) or results.get("result", {}).get("usage", {})
total_tokens += usage.get("total_tokens", 0)
elif isinstance(results, list) and results: elif isinstance(results, list) and results:
for item in results: for item in results:
if "result" in item: if "result" in item:
sub_conversation_id = item["result"].get("conversation_id") sub_conversation_id = item["result"].get("conversation_id")
if sub_conversation_id: if sub_conversation_id:
break break
# 累加每个子 Agent 的 token
usage = item.get("usage", {}) or item.get("result", {}).get("usage", {})
total_tokens += usage.get("total_tokens", 0)
logger.info( logger.info(
"多 Agent 任务完成", "多 Agent 任务完成",
@@ -301,9 +309,15 @@ class MultiAgentOrchestrator:
return { return {
"message": final_result, "message": final_result,
"conversation_id": sub_conversation_id, "conversation_id": sub_conversation_id,
"mode": OrchestrationMode.SUPERVISOR,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"strategy": routing_decision.get("collaboration_strategy", "single"), "strategy": routing_decision.get("collaboration_strategy", "single"),
"sub_results": results "sub_results": results,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
} }
except Exception as e: except Exception as e:
@@ -1552,10 +1566,12 @@ class MultiAgentOrchestrator:
return { return {
"message": result.get("response", ""), "message": result.get("response", ""),
"conversation_id": result.get("conversation_id"), "conversation_id": result.get("conversation_id"),
"mode": OrchestrationMode.COLLABORATION,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"strategy": "collaboration", "strategy": "collaboration",
"active_agent": result.get("active_agent"), "active_agent": result.get("active_agent"),
"sub_results": result "sub_results": result,
"usage": result.get("usage")
} }
except Exception as e: except Exception as e:

View File

@@ -1,5 +1,6 @@
"""多 Agent 配置管理服务""" """多 Agent 配置管理服务"""
import uuid import uuid
import json
from typing import Optional, List, Tuple, Any, Annotated from typing import Optional, List, Tuple, Any, Annotated
from fastapi import Depends from fastapi import Depends
@@ -427,6 +428,23 @@ class MultiAgentService:
memory=getattr(request, 'memory', True) # 记忆功能参数 memory=getattr(request, 'memory', True) # 记忆功能参数
) )
await self._save_conversation_message(
conversation_id=request.conversation_id,
user_message=request.message,
assistant_message=result.get("message", ""),
app_id=app_id,
user_id=request.user_id,
meta_data={
"mode": result.get("mode"),
"elapsed_time": result.get("elapsed_time"),
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
return result return result
async def run_stream( async def run_stream(
@@ -451,11 +469,14 @@ class MultiAgentService:
raise ResourceNotFoundException("多 Agent 配置", str(app_id)) raise ResourceNotFoundException("多 Agent 配置", str(app_id))
if not config.is_active: if not config.is_active:
raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED) raise BusinessException("多 Agent 配置已禁用", BizCode.NOT_FOUND)
# 2. 创建编排器 # 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config) orchestrator = MultiAgentOrchestrator(self.db, config)
full_content = ""
total_tokens = 0
# 3. 流式执行任务 # 3. 流式执行任务
async for event in orchestrator.execute_stream( async for event in orchestrator.execute_stream(
message=request.message, message=request.message,
@@ -468,7 +489,88 @@ class MultiAgentService:
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
): ):
yield event if "sub_usage" in event:
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "total_tokens" in data:
total_tokens += data["total_tokens"]
except:
pass
else:
yield event
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "content" in data:
full_content += data["content"]
except:
pass
await self._save_conversation_message(
conversation_id=request.conversation_id,
user_message=request.message,
assistant_message=full_content,
app_id=app_id,
user_id=request.user_id,
meta_data={
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
}
)
async def _save_conversation_message(
self,
conversation_id: uuid.UUID,
user_message: str,
assistant_message: str,
meta_data: dict,
app_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None
) -> None:
"""保存会话消息
Args:
conversation_id: 会话ID
user_message: 用户消息
assistant_message: AI 回复消息
meta_data: 元数据(包括 token 消耗)
app_id: 应用ID
user_id: 用户ID
"""
try:
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db)
conversation_service.add_message(
conversation_id=conversation_id,
role="user",
content=user_message
)
conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=assistant_message,
meta_data=meta_data
)
logger.debug(
"保存多 Agent 会话消息",
extra={
"conversation_id": conversation_id,
"user_message_length": len(user_message),
"assistant_message_length": len(assistant_message)
}
)
except Exception as e:
logger.warning("保存会话消息失败", extra={"error": str(e)})
# def add_sub_agent( # def add_sub_agent(
# self, # self,