Files
MemoryBear/api/app/services/draft_run_service.py
Ke Sun 79ab929fb0 Release/v0.2.3 (#355)
* feat(web): add PageEmpty component

* feat(web): add PageTabs component

* feat(web): add PageEmpty component

* feat(web): add PageTabs component

* feat(prompt): add history tracking for prompt releases

* feat(web): add prompt menu

* refactor: The PageScrollList component supports two generic parameters

* feat(web): BodyWrapper component updates PageLoading

* feat(web): add Ontology menu

* feat(web): memory management add scene

* feat(tasks): add celery task configuration for periodic jobs

- Add ignore_result=True to prevent storing results for periodic tasks
- Set max_retries=0 to skip failed periodic tasks without retry attempts
- Configure acks_late=False for immediate acknowledgment in beat tasks
- Add time_limit and soft_time_limit to regenerate_memory_cache task (3600s/3300s)
- Add time_limit and soft_time_limit to workspace_reflection_task (300s/240s)
- Add time_limit and soft_time_limit to run_forgetting_cycle_task (7200s/7000s)
- Improve task reliability and resource management for scheduled jobs

* feat(sandbox): add Node.js code execution support to sandbox

* Release/v0.2.2 (#260)

* [modify] migration script

* [add] migration script

* fix(web): change form message

* fix(web): the memoryContent field is compatible with numbers and strings

* feat(web): code node hidden

* fix(model):
1. When creating a basic model, check whether the name and provider are duplicated.
2. The results showed incorrect models because the provider created API Keys for all matching models.

---------

Co-authored-by: Mark <zhuwenhui5566@163.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Timebomb2018 <18868801967@163.com>

* Feature/ontology class clean (#249)

* [add] Complete ontology engineering feature implementation

* [add] Add ontology feature integration and validation utilities

* [add] Add OWL validator and validation utilities

* [fix] Add missing render_ontology_extraction_prompt function

* [fix]Add dependencies, fix functionality

* [add] migration script

* feat(celery): add dedicated periodic tasks worker and queue (#261)

* fix(web): conflict resolve

* Fix/v022 bug (#263)

* [fix]Fix the issue of inconsistent language in explicit and episodic memory.

* [fix]Fix the issue of inconsistent language in explicit and episodic memory.

* [add]Add scene_id

* [fix]Based on the AI review to fix the code

* Fix/develop memory reflex (#265)

* Missing history mapping

* Missing history mapping

* Handle errors reported by the reflection background task

* [add] migration script

* fix: chat conversation_id add node_start

* feat(web): show code node

* fix(web): restructure the CustomSelect component and fix the interface being called multiple times when the form updates

* feat(web): RadioGroupCard support block mode

* feat(web): create space add icon

* feat(app and model): token consumption statistics

* Add/develop memory (#264)

* Missing history mapping

* Missing history mapping

* Missing history mapping

* Missing history mapping

* Missing history mapping

* Missing history mapping

* Missing history mapping

* Missing history mapping

* Missing history mapping

* Add long-term memory feature

* Add long-term memory feature

* Add long-term memory feature

* Remove extra fields from knowledge base retrieval

* Long-term

* feat(app and model): token consumption statistics of the cluster

* memory_BUG_fix

* fix(web): prompt history remove pageLoading

* fix(prompt): remove hard-coded import of prompt file paths (#279)

* Fix/develop memory bug (#274)

* Missing history mapping

* Missing history mapping

* fix_timeline_memories

* fix(web): update retrieve_type key

* Fix/develop memory bug (#276)

* Missing history mapping

* Missing history mapping

* fix_timeline_memories

* fix_timeline_memories

* write_graph/bug_fix

* write_graph/bug_fix

* write_graph/bug_fix

* chore(celery): disable periodic task scheduling

* fix(prompt): remove hard-coded import of prompt file paths

---------

Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Ke Sun <kesun5@illinois.edu>

* fix(web): remove delete confirm content

* refactor(workflow): relocate template directory into workflow

* feat(memory): add long-term storage task routing and batching

* fix(web): PageScrollList loading update

* fix(web): PageScrollList loading update

* Ontology v1 bug (#291)

* [changes]Add 'id' as the secondary sorting key, and 'scene_id' now returns a UUID object

* [fix]Fix the "end_user" return to be sorted by update time.

* [fix]Set the default values of the memory configuration model based on the spatial model.

* [fix]Remove the entity extraction check combination model, read the configuration list, and add the return of scene_id

* [fix]Fix the "end_user" return to be sorted by update time.

* [fix]

* fix(memory): add Redis session validation

- Add macOS fork() safety configuration in celery_app.py to prevent initialization issues
- Add null/False checks for Redis session queries in term_memory_save to handle missing sessions gracefully
- Add null/False checks in memory_long_term_storage to prevent processing empty Redis results
- Add null/False checks in aggregate_judgment before format_parsing to avoid errors on missing data
- Initialize redis_messages variable in window_dialogue for consistency
- Add debug logging when no existing session found in Redis for better troubleshooting
- Add TODO comments for magic numbers (scope=6, time=5) to be extracted as constants
- Improve error handling when Redis returns False or empty results instead of crashing

* fix(web): PageScrollList style update

* fix(workflow): fix argument passing in code execution nodes

* fix(web): prompt add disabled

* fix(web): space icon required

* feat(app): modify the key of the token

* fix(app): fix the key of the app's token

* fix(workflow): switch code input encoding to base64+URL encoding

* [add]The main project adds multi-API Key load balancing.

* [changes]Attribute security access, secure numerical conversion, unified use of local variables

* fix(web): save add session update

* fix(web): language editor support paste

* [changes]Active status filtering logic, API Key selection strategy

* memory_BUG

* memory_BUG_long_term

* [changes]

* memory_BUG_long_term

* memory_BUG_long_term

* Fix/release memory bug (#306)

* memory_BUG_fix

* memory_BUG

* memory_BUG_long_term

* memory_BUG_long_term

* memory_BUG_long_term

* knowledge_retrieval/bug/fix

* knowledge_retrieval/bug/fix

* knowledge_retrieval/bug/fix

* [fix]1.The "read_all_config" interface returns "scene_name";2.Memory configuration for lightweight query ontology scenarios

* fix(web): replace code editor

* [changes]Modify the description of the time for the recent event

* [changes]Modify the code based on the AI review

* feat(web): update memory config ontology api

* fix(web): ui update

* knowledge_retrieval/bug/fix

* knowledge_retrieval/bug/fix

* knowledge_retrieval/bug/fix

* feat(workflow): add token usage statistics for question classifier and parameter extraction

* feat(web): move prompt menu

* Multiple independent transactions - single transaction

* Multiple independent transactions - single transaction

* Multiple independent transactions - single transaction

* Multiple independent transactions - single transaction

* Write Missing None (#321)

* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Fix/release memory bug (#324)

* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

* redis update

* redis update

* redis update

* redis update

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Fix/writer memory bug (#326)

* [fix]Fix the bug

* [fix]Fix the bug

* [fix]Correct the direction indication.

* fix(web): markdown table ui update

* Fix/release memory bug (#332)

* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

* redis update

* redis update

* redis update

* redis update

* writer_dup_bug/fix

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Fix/fact summary (#333)

* [fix]Disable the contents related to fact_summary

* [fix]Disable the contents related to fact_summary

* [fix]Modify the code based on the AI review

* Fix/release memory bug (#335)

* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

* redis update

* redis update

* redis update

* redis update

* writer_dup_bug/fix

* writer_graph_bug/fix

* writer_graph_bug/fix

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Revert "feat(web): move prompt menu"

This reverts commit 9e6e8f50f8.

* fix(web): ui update

* fix(web): update text

* fix(web): ui update

* fix(model): change the "vl" model type of dashscope to "chat"

* fix(model): change the "vl" model type of dashscope to "chat"

---------

Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: Eternity <1533512157@qq.com>
Co-authored-by: Mark <zhuwenhui5566@163.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Timebomb2018 <18868801967@163.com>
Co-authored-by: 乐力齐 <162269739+lanceyq@users.noreply.github.com>
Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com>
Co-authored-by: lixinyue <2569494688@qq.com>
Co-authored-by: Eternity <61316157+myhMARS@users.noreply.github.com>
Co-authored-by: lanceyq <1982376970@qq.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2026-02-06 19:01:57 +08:00

1630 lines
61 KiB
Python
"""
试运行服务
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
"""
import asyncio
import datetime
import json
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.repositories.model_repository import ModelApiKeyRepository
from app.repositories.tool_repository import ToolRepository
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.tool_service import ToolService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
    """Input schema for the knowledge base retrieval tool."""
    query: str = Field(description="The question or keywords to retrieve")


class WebSearchInput(BaseModel):
    """Input schema for the web search tool."""
    query: str = Field(description="The question or keywords to search for")


class LongTermMemoryInput(BaseModel):
    """Input schema for the long-term memory tool."""
    question: str = Field(description="The rewritten, optimized query. Rewrite the user's original question into a form better suited for retrieval, including keywords, context, and concrete details; check for and correct typos while rewriting")
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None):
    """Create the long-term memory tool.

    Args:
        memory_config: Memory configuration
        end_user_id: End user ID
        storage_type: Storage type (optional)

    Returns:
        The long-term memory tool
    """
    # search_switch = memory_config.get("search_switch", "2")
    config_id = memory_config.get("memory_content") or memory_config.get("memory_config")
    logger.info(f"Creating long-term memory tool: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
    @tool(args_schema=LongTermMemoryInput)
    def long_term_memory(question: str) -> str:
        """
        Retrieve relevant information from the user's historical memories. This is a
        powerful tool that helps you understand the user's background, preferences,
        and past conversations.

        Do NOT use this tool in the following scenarios:
        1. Emotional/social greetings (simple pleasantries such as "hello", "thanks", "goodbye")
        2. Purely task-oriented requests (e.g. "write some code for me", "translate this text") that need no historical context
        3. Processing external content supplied by the user (text, code, RAG data, etc.) that already contains the required information

        Use this tool in every other situation, especially when:
        - The user asks about personal information or past conversation content
        - You need to understand the user's preferences, habits, or background
        - The user uses history-related words such as "before", "last time", "remember"
        - A personalized reply or a suggestion based on historical context is needed
        - The user asks anything about themselves

        The question must be rewritten/optimized, paying attention to:
        - Relevant keywords: keep the core semantics of the original question, use the
          context to make it more specific and clear, and turn vague wording into
          explicit search terms
        - Expanding the query with synonyms or related terms

        Args:
            question: The rewritten question

        Returns:
            The retrieved historical memory content
        """
        logger.info(f"Long-term memory tool invoked: question={question}, user={end_user_id}")
        try:
            from app.db import get_db
            db = next(get_db())
            try:
                memory_content = asyncio.run(
                    MemoryAgentService().read_memory(
                        end_user_id=end_user_id,
                        message=question,
                        history=[],
                        search_switch="2",
                        config_id=config_id,
                        db=db,
                        storage_type=storage_type,
                        user_rag_memory_id=user_rag_memory_id
                    )
                )
                task = celery_app.send_task(
                    "app.core.memory.agent.read_message",
                    args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
                )
                result = task_service.get_task_memory_read_result(task.id)
                status = result.get("status")
                logger.info(f"Read task status: {status}")
                if memory_content:
                    memory_content = memory_content['answer']
            finally:
                db.close()
            logger.info(f'Agent end user ID: {end_user_id}')
            logger.debug("Calling the long-term memory API", extra={"question": question, "end_user_id": end_user_id})
            logger.info(
                "Long-term memory retrieval succeeded",
                extra={
                    "end_user_id": end_user_id,
                    "content_length": len(str(memory_content))
                }
            )
            return f"Retrieved the following historical memories:\n\n{memory_content}"
        except Exception as e:
            logger.error("Long-term memory retrieval failed", extra={"error": str(e), "error_type": type(e).__name__})
            return f"Memory retrieval failed: {str(e)}"
    return long_term_memory
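
# A minimal usage sketch for the factory above (illustrative only; the config
# keys and IDs are hypothetical):
#
#     memory_tool = create_long_term_memory_tool(
#         {"memory_content": "cfg-123"},   # hypothetical memory config
#         end_user_id="user-42",
#     )
#     answer = memory_tool.invoke({"question": "Which city did I say I live in?"})
#
# create_web_search_tool and create_knowledge_retrieval_tool below follow the
# same pattern: each factory closes over its configuration and returns a
# LangChain @tool the agent can call by name.
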
def create_web_search_tool(web_search_config: Dict[str, Any]):
    """Create the web search tool.

    Args:
        web_search_config: Web search configuration

    Returns:
        The web search tool
    """
    logger.info("Creating web search tool")

    @tool(args_schema=WebSearchInput)
    def web_search_tool(query: str) -> str:
        """Search the internet for up-to-date information. Use this tool when the user's question requires real-time information, recent news, or online material.

        Args:
            query: The question or keywords to search for

        Returns:
            Relevant information found on the web
        """
        try:
            logger.info(f"Running web search: {query}")
            # Call the search service
            search_result = Search(query)
            logger.info(
                "Web search succeeded",
                extra={
                    "query": query,
                    "result_length": len(search_result)
                }
            )
            return f"Found the following web information:\n\n{search_result}"
        except Exception as e:
            logger.error("Web search failed", extra={"error": str(e), "error_type": type(e).__name__})
            return f"Search failed: {str(e)}"
    return web_search_tool
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
    """Create the knowledge base retrieval tool.

    Args:
        kb_config: Knowledge base retrieval configuration
        kb_ids: Knowledge base IDs (used for logging)
        user_id: User ID

    Returns:
        The knowledge base retrieval tool
    """
    logger.info(f"Creating knowledge base retrieval tool, user: {user_id}")

    @tool(args_schema=KnowledgeRetrievalInput)
    def knowledge_retrieval_tool(query: str) -> str:
        """Retrieve relevant information from the knowledge base. Use this tool when the user's question requires consulting the knowledge base, documents, or historical records.

        Args:
            query: The question or keywords to retrieve

        Returns:
            The relevant knowledge content retrieved
        """
        try:
            retrieve_chunks_result = knowledge_retrieval(query, kb_config)
            if retrieve_chunks_result:
                retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
                context = '\n\n'.join(retrieval_knowledge)
                logger.info(
                    "Knowledge base retrieval succeeded",
                    extra={
                        "kb_ids": kb_ids,
                        "result_count": len(retrieval_knowledge),
                        "total_length": len(context)
                    }
                )
                return f"Retrieved the following relevant information:\n\n{context}"
            else:
                logger.warning("Knowledge base retrieval returned no results")
                return "No relevant information found"
        except Exception as e:
            logger.error("Knowledge base retrieval failed", extra={"error": str(e), "error_type": type(e).__name__})
            return f"Retrieval failed: {str(e)}"
    return knowledge_retrieval_tool


class DraftRunService:
    """Draft run service class."""

    def __init__(self, db: Session):
        """Initialize the draft run service.

        Args:
            db: Database session
        """
        self.db = db

    async def run(
        self,
        *,
        agent_config: AgentConfig,
        model_config: ModelConfig,
        message: str,
        workspace_id: uuid.UUID,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
        web_search: bool = True,
        memory: bool = True,
        sub_agent: bool = False
    ) -> Dict[str, Any]:
        """Execute a draft run (using a LangChain Agent).

        Args:
            agent_config: Agent configuration
            model_config: Model configuration
            message: User message
            workspace_id: Workspace ID (required; used for conversation isolation)
            conversation_id: Conversation ID (for multi-turn dialogue)
            user_id: User ID
            variables: Custom variable values
            storage_type: Storage type (optional)
            user_rag_memory_id: RAG memory ID (optional)
            web_search: Whether web search is enabled
            memory: Whether memory is enabled
            sub_agent: Whether this run executes as a sub-agent

        Returns:
            Dict: A dictionary containing the AI reply and metadata
        """
        memory_flag = False
        logger.debug("Draft run starting", extra={"storage_type": storage_type, "user_id": user_id})
        if variables is None:
            variables = {}
from app.core.agent.langchain_agent import LangChainAgent
start_time = time.time()
        try:
            # 1. Fetch the API Key configuration
            api_key_config = await self._get_api_key(model_config.id)
            logger.debug(
                "API Key configuration loaded",
                extra={
                    "model_name": api_key_config["model_name"],
                    "has_api_key": bool(api_key_config["api_key"]),
                    "has_api_base": bool(api_key_config.get("api_base"))
                }
            )
            # 2. Merge model parameters
            effective_params = ModelParameterMerger.get_effective_parameters(
                model_config=model_config,
                agent_config=agent_config
            )
            # 3. Render the system prompt (with variable substitution)
            items_params = variables
            system_prompt = render_prompt_message(
                agent_config.system_prompt,
                PromptMessageRole.USER,
                items_params
            )
            system_prompt = system_prompt.get_text_content() or "You are a professional AI assistant"
            logger.debug("System prompt rendered", extra={"system_prompt": system_prompt})
            # 4. Build the tool list
            tools = []
            tool_service = ToolService(self.db)
            # Collect enabled tools from the configuration
            if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
                for tool_config in agent_config.tools:
                    logger.debug("Inspecting tool config", extra={"tool_config": tool_config})
                    if tool_config.get("enabled", False):
                        # Look up the tool instance by tool ID
                        tool_instance = tool_service._get_tool_instance(
                            tool_config.get("tool_id", ""),
                            ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
                        )
                        if tool_instance:
                            if tool_instance.name == "baidu_search_tool" and not web_search:
                                continue
                            # Convert to a LangChain tool
                            langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
                            tools.append(langchain_tool)
            elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
                web_tools = agent_config.tools
                web_search_choice = web_tools.get("web_search", {})
                web_search_enable = web_search_choice.get("enabled", False)
                if web_search and web_search_enable:
                    search_tool = create_web_search_tool({})
                    tools.append(search_tool)
                    logger.debug(
                        "Web search tool added",
                        extra={
                            "tool_count": len(tools)
                        }
                    )
            # Add the knowledge base retrieval tool
            if agent_config.knowledge_retrieval:
                kb_config = agent_config.knowledge_retrieval
                knowledge_bases = kb_config.get("knowledge_bases", [])
                kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))  # flag: at least one kb_id configured
                if kb_ids:
                    # Create the knowledge base retrieval tool
                    kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
                    tools.append(kb_tool)
                    logger.debug(
                        "Knowledge base retrieval tool added",
                        extra={
                            "kb_ids": kb_ids,
                            "tool_count": len(tools)
                        }
                    )
            # Add the long-term memory tool
            if memory and agent_config.memory and agent_config.memory.get("enabled"):
                memory_flag = True
                memory_config = agent_config.memory
                if user_id:
                    # Create the long-term memory tool
                    memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, user_rag_memory_id)
                    tools.append(memory_tool)
                    logger.debug(
                        "Long-term memory tool added",
                        extra={
                            "user_id": user_id,
                            "tool_count": len(tools)
                        }
                    )
            # 5. Create the LangChain Agent
            agent = LangChainAgent(
                model_name=api_key_config["model_name"],
                api_key=api_key_config["api_key"],
                provider=api_key_config.get("provider", "openai"),
                api_base=api_key_config.get("api_base"),
                temperature=effective_params.get("temperature", 0.7),
                max_tokens=effective_params.get("max_tokens", 2000),
                system_prompt=system_prompt,
                tools=tools,
            )
            # 6. Resolve the conversation ID (create or validate)
            conversation_id = await self._ensure_conversation(
                conversation_id=conversation_id,
                app_id=agent_config.app_id,
                workspace_id=workspace_id,
                user_id=user_id
            )
            # 7. Load conversation history
            history = []
            if agent_config.memory and agent_config.memory.get("enabled"):
                history = await self._load_conversation_history(
                    conversation_id=conversation_id,
                    max_history=agent_config.memory.get("max_history", 10)
                )
            # 8. Knowledge base retrieval context (currently handled via tools)
            context = None
            logger.debug(
                "About to call the LangChain Agent",
                extra={
                    "model": api_key_config["model_name"],
                    "has_history": bool(history),
                    "has_context": bool(context)
                }
            )
            memory_config_ = agent_config.memory or {}
            config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config")
            # 9. Call the Agent
            result = await agent.chat(
                message=message,
                history=history,
                context=context,
                end_user_id=user_id,
                config_id=config_id,
                storage_type=storage_type,
                user_rag_memory_id=user_rag_memory_id,
                memory_flag=memory_flag
            )
            elapsed_time = time.time() - start_time
            # 10. Save the conversation messages
            if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
                await self._save_conversation_message(
                    conversation_id=conversation_id,
                    user_message=message,
                    assistant_message=result["content"],
                    app_id=agent_config.app_id,
                    user_id=user_id,
                    meta_data={
                        "usage": result.get("usage", {
                            "prompt_tokens": 0,
                            "completion_tokens": 0,
                            "total_tokens": 0
                        })
                    }
                )
            response = {
                "message": result["content"],
                "conversation_id": conversation_id,
                "usage": result.get("usage", {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0
                }),
                "elapsed_time": elapsed_time
            }
            logger.info(
                "Draft run completed",
                extra={
                    "model": model_config.name,
                    "elapsed_time": elapsed_time,
                    "message_length": len(result["content"]),
                    "total_tokens": result.get("usage", {}).get("total_tokens", 0)
                }
            )
            return response
        except Exception as e:
            logger.error("LangChain Agent call failed", extra={"error": str(e), "error_type": type(e).__name__})
            raise BusinessException(f"Agent call failed: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
    async def run_stream(
        self,
        *,
        agent_config: AgentConfig,
        model_config: ModelConfig,
        message: str,
        workspace_id: uuid.UUID,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
        web_search: bool = True,
        memory: bool = True,
        sub_agent: bool = False  # whether this run executes as a sub-agent
    ) -> AsyncGenerator[str, None]:
        """Execute a draft run (streaming, using a LangChain Agent).

        Args:
            agent_config: Agent configuration
            model_config: Model configuration
            message: User message
            workspace_id: Workspace ID (required; used for conversation isolation)
            conversation_id: Conversation ID (for multi-turn dialogue)
            user_id: User ID
            variables: Custom variable values

        Yields:
            str: SSE-formatted event data
        """
        memory_flag = False
        if variables is None:
            variables = {}
        from app.core.agent.langchain_agent import LangChainAgent
        start_time = time.time()
        try:
            # 1. Fetch the API Key configuration
            api_key_config = await self._get_api_key(model_config.id)
            # 2. Merge model parameters
            effective_params = ModelParameterMerger.get_effective_parameters(
                model_config=model_config,
                agent_config=agent_config
            )
            # 3. Render the system prompt (with variable substitution)
            items_params = variables
            system_prompt = render_prompt_message(
                agent_config.system_prompt,
                PromptMessageRole.USER,
                items_params
            )
            system_prompt = system_prompt.get_text_content() or "You are a professional AI assistant"
            # 4. Build the tool list
            tools = []
            tool_service = ToolService(self.db)
            # Collect enabled tools from the configuration
            if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
                for tool_config in agent_config.tools:
                    if tool_config.get("enabled", False):
                        # Look up the tool instance by tool ID
                        tool_instance = tool_service._get_tool_instance(
                            tool_config.get("tool_id", ""),
                            ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
                        )
                        if tool_instance:
                            if tool_instance.name == "baidu_search_tool" and not web_search:
                                continue
                            # Convert to a LangChain tool
                            langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
                            tools.append(langchain_tool)
            elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
                web_tools = agent_config.tools
                web_search_choice = web_tools.get("web_search", {})
                web_search_enable = web_search_choice.get("enabled", False)
                if web_search and web_search_enable:
                    search_tool = create_web_search_tool({})
                    tools.append(search_tool)
                    logger.debug(
                        "Web search tool added",
                        extra={
                            "tool_count": len(tools)
                        }
                    )
            # Add the knowledge base retrieval tool
            if agent_config.knowledge_retrieval:
                kb_config = agent_config.knowledge_retrieval
                knowledge_bases = kb_config.get("knowledge_bases", [])
                kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))  # flag: at least one kb_id configured
                if kb_ids:
                    # Create the knowledge base retrieval tool
                    kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
                    tools.append(kb_tool)
                    logger.debug(
                        "Knowledge base retrieval tool added",
                        extra={
                            "kb_ids": kb_ids,
                            "tool_count": len(tools)
                        }
                    )
            # Add the long-term memory tool
            if memory and agent_config.memory and agent_config.memory.get("enabled"):
                memory_flag = True
                memory_config = agent_config.memory
                if user_id:
                    # Create the long-term memory tool
                    memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, user_rag_memory_id)
                    tools.append(memory_tool)
                    logger.debug(
                        "Long-term memory tool added",
                        extra={
                            "user_id": user_id,
                            "tool_count": len(tools)
                        }
                    )
            # 5. Create the LangChain Agent
            agent = LangChainAgent(
                model_name=api_key_config["model_name"],
                api_key=api_key_config["api_key"],
                provider=api_key_config.get("provider", "openai"),
                api_base=api_key_config.get("api_base"),
                temperature=effective_params.get("temperature", 0.7),
                max_tokens=effective_params.get("max_tokens", 2000),
                system_prompt=system_prompt,
                tools=tools,
                streaming=True
            )
            # 6. Resolve the conversation ID (create or validate)
            conversation_id = await self._ensure_conversation(
                conversation_id=conversation_id,
                app_id=agent_config.app_id,
                workspace_id=workspace_id,
                user_id=user_id
            )
            # 7. Load conversation history
            history = []
            if agent_config.memory and agent_config.memory.get("enabled"):
                history = await self._load_conversation_history(
                    conversation_id=conversation_id,
                    max_history=agent_config.memory.get("max_history", 10)
                )
            # 8. Knowledge base retrieval context (currently handled via tools)
            context = None
            # 9. Emit the start event
            yield self._format_sse_event("start", {
                "conversation_id": conversation_id,
                "timestamp": time.time()
            })
            memory_config_ = agent_config.memory or {}
            config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config")
            # 10. Stream the Agent call
            full_content = ""
            total_tokens = 0
            async for chunk in agent.chat_stream(
                message=message,
                history=history,
                context=context,
                end_user_id=user_id,
                config_id=config_id,
                storage_type=storage_type,
                user_rag_memory_id=user_rag_memory_id,
                memory_flag=memory_flag
            ):
                if isinstance(chunk, int):
                    total_tokens = chunk
                else:
                    full_content += chunk
                    # Emit a message-chunk event
                    yield self._format_sse_event("message", {
                        "content": chunk
                    })
            elapsed_time = time.time() - start_time
            if sub_agent:
                yield self._format_sse_event("sub_usage", {
                    "total_tokens": total_tokens
                })
            # 11. Save the conversation messages
            if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
                await self._save_conversation_message(
                    conversation_id=conversation_id,
                    user_message=message,
                    assistant_message=full_content,
                    app_id=agent_config.app_id,
                    user_id=user_id,
                    meta_data={
                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
                    }
                )
            # 12. Emit the end event
            yield self._format_sse_event("end", {
                "conversation_id": conversation_id,
                "elapsed_time": elapsed_time,
                "message_length": len(full_content)
            })
            logger.info(
                "Streaming draft run completed",
                extra={
                    "model": model_config.name,
                    "elapsed_time": elapsed_time,
                    "message_length": len(full_content)
                }
            )
        except Exception as e:
            logger.error("Streaming Agent call failed", extra={"error": str(e)}, exc_info=True)
            # Emit an error event
            yield self._format_sse_event("error", {
                "error": str(e),
                "timestamp": time.time()
            })
    def _format_sse_event(self, event: str, data: Dict[str, Any]) -> str:
        """Format an SSE event.

        Args:
            event: Event type
            data: Event payload

        Returns:
            str: The SSE-formatted string
        """
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
    async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]:
        """Fetch the model's API Key.

        Args:
            model_config_id: Model configuration ID

        Returns:
            Dict: A dictionary with model_name, api_key, and api_base

        Raises:
            BusinessException: When no API Key is available
        """
        api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
        # Previous direct-query implementation, kept for reference:
        # stmt = (
        #     select(ModelApiKey).join(
        #         ModelConfig, ModelApiKey.model_configs
        #     )
        #     .where(
        #         ModelConfig.id == model_config_id,
        #         ModelApiKey.is_active.is_(True)
        #     )
        #     .order_by(ModelApiKey.priority.desc())
        #     .limit(1)
        # )
        # api_key = self.db.scalars(stmt).first()
        api_key = api_keys[0] if api_keys else None
        if not api_key:
            raise BusinessException("No API Key available", BizCode.AGENT_CONFIG_MISSING)
        return {
            "model_name": api_key.model_name,
            "provider": api_key.provider,
            "api_key": api_key.api_key,
            "api_base": api_key.api_base
        }
    async def _ensure_conversation(
        self,
        conversation_id: Optional[str],
        app_id: uuid.UUID,
        workspace_id: uuid.UUID,
        user_id: Optional[str]
    ) -> str:
        """Ensure the conversation exists (create or validate).

        Args:
            conversation_id: Conversation ID (optional)
            app_id: App ID
            workspace_id: Workspace ID (required)
            user_id: User ID

        Returns:
            str: The conversation ID

        Raises:
            BusinessException: When the specified conversation does not exist
        """
        from app.models import Conversation as ConversationModel
        from app.schemas.conversation_schema import ConversationCreate
        from app.services.conversation_service import ConversationService
        conversation_service = ConversationService(self.db)
        # If no conversation ID was provided, create a new conversation
        if not conversation_id:
            logger.info(
                "Creating a new draft conversation",
                extra={"workspace_id": str(workspace_id)}
            )
            # Take a configuration snapshot
            config_snapshot = await self._get_config_snapshot(app_id)
            # Create the new conversation
            new_conv_id = str(uuid.uuid4())
            new_conversation = ConversationModel(
                id=uuid.UUID(new_conv_id),
                app_id=app_id,
                workspace_id=workspace_id,
                user_id=user_id,
                is_draft=True,
                title="Draft conversation",
                config_snapshot=config_snapshot
            )
            self.db.add(new_conversation)
            self.db.commit()
            self.db.refresh(new_conversation)
            logger.info(
                "Draft conversation created",
                extra={
                    "conversation_id": new_conv_id,
                    "workspace_id": str(workspace_id)
                }
            )
            return new_conv_id
        # If a conversation ID was provided, validate that it exists and belongs to this workspace
        try:
            conv_uuid = uuid.UUID(conversation_id)
            conversation = conversation_service.get_conversation(conv_uuid)
            # Verify the conversation belongs to the current workspace
            if conversation.workspace_id != workspace_id:
                logger.warning(
                    "Conversation does not belong to the current workspace",
                    extra={
                        "conversation_id": conversation_id,
                        "conversation_workspace_id": str(conversation.workspace_id),
                        "current_workspace_id": str(workspace_id)
                    }
                )
                raise BusinessException(
                    "Conversation does not belong to the current workspace",
                    BizCode.PERMISSION_DENIED
                )
            logger.debug(
                "Using existing conversation",
                extra={
                    "conversation_id": conversation_id,
                    "workspace_id": str(workspace_id)
                }
            )
            return conversation_id
        except BusinessException:
            raise
        except Exception as e:
            logger.error(
                "Conversation does not exist or is invalid",
                extra={"conversation_id": conversation_id, "error": str(e)}
            )
            raise BusinessException(
                f"Conversation does not exist: {conversation_id}",
                BizCode.NOT_FOUND,
                cause=e
            )
    async def _load_conversation_history(
        self,
        conversation_id: str,
        max_history: int = 10
    ) -> List[Dict[str, str]]:
        """Load the conversation's message history.

        Args:
            conversation_id: Conversation ID
            max_history: Maximum number of history messages

        Returns:
            List[Dict]: The list of history messages
        """
        try:
            from app.services.conversation_service import ConversationService
            conversation_service = ConversationService(self.db)
            history = conversation_service.get_conversation_history(
                conversation_id=uuid.UUID(conversation_id),
                max_history=max_history
            )
            logger.debug(
                "Conversation history loaded",
                extra={
                    "conversation_id": conversation_id,
                    "max_history": max_history,
                    "loaded_count": len(history)
                }
            )
            return history
        except Exception as e:
            # A new conversation having no history is normal
            logger.debug("Failed to load conversation history (possibly a new conversation)", extra={"error": str(e)})
            return []
    async def _save_conversation_message(
        self,
        conversation_id: str,
        user_message: str,
        assistant_message: str,
        meta_data: dict,
        app_id: Optional[uuid.UUID] = None,
        user_id: Optional[str] = None
    ) -> None:
        """Save conversation messages (the conversation is guaranteed to exist via _ensure_conversation).

        Args:
            conversation_id: Conversation ID
            user_message: User message
            assistant_message: AI reply message
            meta_data: Token usage
            app_id: App ID (unused; kept for compatibility)
            user_id: User ID (unused; kept for compatibility)
        """
        try:
            from app.services.conversation_service import ConversationService
            conversation_service = ConversationService(self.db)
            conv_uuid = uuid.UUID(conversation_id)
            # Save the user message
            conversation_service.add_message(
                conversation_id=conv_uuid,
                role="user",
                content=user_message
            )
            # Save the assistant message
            conversation_service.add_message(
                conversation_id=conv_uuid,
                role="assistant",
                content=assistant_message,
                meta_data=meta_data
            )
            logger.debug(
                "Conversation messages saved",
                extra={
                    "conversation_id": conversation_id,
                    "user_message_length": len(user_message),
                    "assistant_message_length": len(assistant_message)
                }
            )
        except Exception as e:
            logger.warning("Failed to save conversation messages", extra={"error": str(e)})
    async def _get_config_snapshot(self, app_id: uuid.UUID) -> Dict[str, Any]:
        """Get a snapshot of the current configuration.

        Args:
            app_id: App ID

        Returns:
            Dict: The configuration snapshot
        """
        try:
            from app.models import AgentConfig, ModelConfig
            # Fetch the Agent configuration
            stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
            agent_cfg = self.db.scalars(stmt).first()
            if not agent_cfg:
                return {}
            # Fetch the model configuration
            model_config = None
            if agent_cfg.default_model_config_id:
                model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)

            # Build the snapshot (ensuring every value is serializable)
            def safe_serialize(value):
                """Serialize a value safely."""
                if value is None:
                    return None
                if isinstance(value, (str, int, float, bool)):
                    return value
                if isinstance(value, (dict, list)):
                    return value
                # For Pydantic models and other objects, try converting to a dict
                if hasattr(value, 'dict'):
                    return value.dict()
                if hasattr(value, '__dict__'):
                    return value.__dict__
                return str(value)

            snapshot = {
                "agent_config": {
                    "system_prompt": agent_cfg.system_prompt,
                    "model_parameters": safe_serialize(agent_cfg.model_parameters),
                    "knowledge_retrieval": safe_serialize(agent_cfg.knowledge_retrieval),
                    "memory": safe_serialize(agent_cfg.memory),
                    "variables": safe_serialize(agent_cfg.variables),
                    "tools": safe_serialize(agent_cfg.tools)
                },
                "model_config": {
                    "model_name": model_config.name if model_config else None,
                    "provider": model_config.provider if model_config else None,
                    "type": model_config.type if model_config else None
                } if model_config else None,
                "snapshot_time": datetime.datetime.now().isoformat()
            }
            return snapshot
        except Exception as e:
            # Multi-Agent apps having no direct AgentConfig is normal
            logger.debug("Failed to take a config snapshot (possibly a multi-Agent app)", extra={"error": str(e)})
            return {}
    def _replace_variables(
        self,
        text: str,
        values: Dict[str, Any],
        definitions: List[Dict[str, Any]]
    ) -> str:
        """Replace variables in the text.

        Args:
            text: Original text
            values: Variable values
            definitions: Variable definitions

        Returns:
            str: The text after replacement
        """
        result = text
        # Build a mapping of variable definitions
        var_defs = {var["name"]: var for var in definitions}
        for var_name, var_value in values.items():
            # Skip variables that are not defined
            if var_name not in var_defs:
                logger.warning(f"Undefined variable: {var_name}")
                continue
            # Replace the variable (several placeholder formats are supported)
            placeholders = [
                f"{{{{{var_name}}}}}",  # {{var_name}}
                f"{{{var_name}}}",      # {var_name}
                f"${{{var_name}}}",     # ${var_name}
            ]
            for placeholder in placeholders:
                if placeholder in result:
                    result = result.replace(placeholder, str(var_value))
        return result
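
    # A small illustration of the placeholder formats handled above (values
    # are hypothetical):
    #
    #     self._replace_variables("Hi {{name}}", {"name": "Ada"}, [{"name": "name"}])
    #     -> "Hi Ada"
    #
    # "{name}" and "${name}" placeholders are rewritten the same way. Note
    # that "${name}" contains "{name}" as a substring, so when one text mixes
    # formats, the replacement order above determines the result.
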
    # ==================== Multi-model comparison draft run ====================
    async def run_compare(
        self,
        *,
        agent_config: AgentConfig,
        models: List[Dict[str, Any]],
        message: str,
        workspace_id: uuid.UUID,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        parallel: bool = True,
        timeout: int = 60,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
        web_search: bool = True,
        memory: bool = True,
    ) -> Dict[str, Any]:
        """Multi-model comparison draft run.

        Args:
            agent_config: Agent configuration
            models: List of model configurations; each item contains model_config, parameters, label, model_config_id
            message: User message
            workspace_id: Workspace ID
            conversation_id: Conversation ID
            user_id: User ID
            variables: Variable values
            parallel: Whether to run in parallel
            timeout: Timeout in seconds

        Returns:
            Dict: The comparison result
        """
        logger.info(
            "Multi-model comparison draft run",
            extra={
                "model_count": len(models),
                "parallel": parallel
            }
        )
        async def run_single_model(model_info):
            """Run a single model."""
            try:
                start_time = time.time()
                # Temporarily override parameters (deepcopy is avoided to prevent SQLAlchemy session issues)
                original_params = agent_config.model_parameters
                agent_config.model_parameters = model_info["parameters"]
                # Use the model's own conversation_id, falling back to the global one
                model_conversation_id = model_info.get("conversation_id") or conversation_id
                try:
                    result = await asyncio.wait_for(
                        self.run(
                            agent_config=agent_config,
                            model_config=model_info["model_config"],
                            message=message,
                            workspace_id=workspace_id,
                            conversation_id=model_conversation_id,
                            user_id=user_id,
                            variables=variables,
                            storage_type=storage_type,
                            user_rag_memory_id=user_rag_memory_id,
                            web_search=web_search,
                            memory=memory
                        ),
                        timeout=timeout
                    )
                finally:
                    # Restore the original parameters
                    agent_config.model_parameters = original_params
                elapsed = time.time() - start_time
                usage = result.get("usage", {})
                return {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "label": model_info["label"],
                    "conversation_id": result['conversation_id'],
                    "parameters_used": model_info["parameters"],
                    "message": result.get("message"),
                    "usage": usage,
                    "elapsed_time": elapsed,
                    "tokens_per_second": (
                        usage.get("completion_tokens", 0) / elapsed
                        if elapsed > 0 and usage.get("completion_tokens") else None
                    ),
                    "cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
                    "error": None
                }
            except TimeoutError:
                logger.warning(
                    "Model run timed out",
                    extra={
                        "model_config_id": str(model_info["model_config_id"]),
                        "timeout": timeout
                    }
                )
                return {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "conversation_id": conversation_id,
                    "label": model_info["label"],
                    "parameters_used": model_info["parameters"],
                    "elapsed_time": timeout,
                    "error": f"Execution timed out ({timeout}s)"
                }
            except Exception as e:
                logger.error(
                    "Model run failed",
                    extra={
                        "model_config_id": str(model_info["model_config_id"]),
                        "error": str(e)
                    }
                )
                return {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "label": model_info["label"],
                    "conversation_id": conversation_id,
                    "parameters_used": model_info["parameters"],
                    "elapsed_time": 0,
                    "error": str(e)
                }
        # Run all models (in parallel or serially)
        if parallel:
            logger.debug(f"Running {len(models)} models in parallel")
            results = await asyncio.gather(
                *[run_single_model(m) for m in models],
                return_exceptions=False
            )
        else:
            logger.debug(f"Running {len(models)} models serially")
            results = []
            for model_info in models:
                result = await run_single_model(model_info)
                results.append(result)
        # Statistics
        successful = [r for r in results if not r.get("error")]
        failed = [r for r in results if r.get("error")]
        fastest = min(successful, key=lambda x: x["elapsed_time"]) if successful else None
        cheapest = min(
            successful,
            key=lambda x: x.get("cost_estimate") or float("inf")
        ) if successful else None
        logger.info(
            "Multi-model comparison completed",
            extra={
                "successful": len(successful),
                "failed": len(failed),
                "total_time": sum(r.get("elapsed_time", 0) for r in results)
            }
        )
        return {
            "results": results,
            "total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results),
            "successful_count": len(successful),
            "failed_count": len(failed),
            "fastest_model": fastest["label"] if fastest else None,
            "cheapest_model": cheapest["label"] if cheapest else None
        }
    def _estimate_cost(self, usage: Dict[str, Any], model_config) -> Optional[float]:
        """Estimate the cost.

        Args:
            usage: Token usage
            model_config: Model configuration

        Returns:
            Optional[float]: Estimated cost (USD)
        """
        if not usage:
            return None
        prompt_tokens = usage.get("prompt_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0)
        # Simplified cost estimation: return None for now
        # TODO: implement cost estimation based on the model name or configuration;
        # this requires the actual model name from ModelApiKey, or a model field on ModelConfig
        return None

    def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]):
        """Create an agent_config with overridden parameters (shallow: only model_parameters changes).

        Args:
            agent_config: Original Agent configuration
            parameters: Parameters to apply

        Returns:
            Tuple[AgentConfig, Any]: The modified config and the original parameters
            (note: this is the same object; model_parameters is temporarily replaced)
        """
        # Save the original parameters
        original_params = agent_config.model_parameters
        # Apply the new parameters
        agent_config.model_parameters = parameters
        return agent_config, original_params
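
    # Intended usage of the helper above (a sketch; since the same AgentConfig
    # object is mutated, the caller must restore the saved parameters, ideally
    # in a finally block):
    #
    #     cfg, saved = self._with_parameters(agent_config, {"temperature": 0.2})
    #     try:
    #         ...  # run with the overridden parameters
    #     finally:
    #         cfg.model_parameters = saved
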
    async def run_compare_stream(
        self,
        *,
        agent_config: AgentConfig,
        models: List[Dict[str, Any]],
        message: str,
        workspace_id: uuid.UUID,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
        web_search: bool = True,
        memory: bool = True,
        parallel: bool = True,
        timeout: int = 60
    ) -> AsyncGenerator[str, None]:
        """Multi-model comparison draft run (streaming).

        Mirrors the run_compare implementation; supports parallel or serial execution.

        Args:
            agent_config: Agent configuration
            models: List of model configurations; each item contains model_config, parameters, label, model_config_id
            message: User message
            workspace_id: Workspace ID
            conversation_id: Conversation ID
            user_id: User ID
            variables: Variable values
            storage_type: Storage type
            user_rag_memory_id: RAG memory ID
            web_search: Whether web search is enabled
            memory: Whether memory is enabled
            parallel: Whether to run in parallel
            timeout: Timeout in seconds

        Yields:
            str: SSE-formatted event data
        """
        logger.info(
            "Multi-model comparison streaming draft run",
            extra={"model_count": len(models), "parallel": parallel}
        )
        # Emit the comparison start event
        yield self._format_sse_event("compare_start", {
            "conversation_id": conversation_id,
            "model_count": len(models),
            "parallel": parallel,
            "timestamp": time.time()
        })
        results = []
        async def run_single_model_stream(idx: int, model_info: Dict[str, Any], event_queue: asyncio.Queue):
            """Run a single model (streaming) and push its events onto the queue."""
            model_label = model_info["label"]
            model_config_id = str(model_info["model_config_id"])
            # Use the model's own conversation_id, falling back to the global one
            model_conversation_id = model_info.get("conversation_id") or conversation_id
            try:
                # Emit the model start event
                await event_queue.put(self._format_sse_event("model_start", {
                    "model_index": idx,
                    "model_config_id": model_config_id,
                    "model_name": model_info["model_config"].name,
                    "label": model_label,
                    "conversation_id": model_conversation_id,
                    "timestamp": time.time()
                }))
                start_time = time.time()
                full_content = ""
                returned_conversation_id = model_conversation_id
                # Temporarily override parameters
                original_params = agent_config.model_parameters
                agent_config.model_parameters = model_info["parameters"]
                try:
                    # Stream the single-model call
                    async for event_str in self.run_stream(
                        agent_config=agent_config,
                        model_config=model_info["model_config"],
                        message=message,
                        workspace_id=workspace_id,
                        conversation_id=model_conversation_id,
                        user_id=user_id,
                        variables=variables,
                        storage_type=storage_type,
                        user_rag_memory_id=user_rag_memory_id,
                        web_search=web_search,
                        memory=memory
                    ):
                        # Parse the raw event
                        try:
                            lines = event_str.strip().split('\n')
                            event_type = None
                            event_data = None
                            for line in lines:
                                if line.startswith('event: '):
                                    event_type = line[7:].strip()
                                elif line.startswith('data: '):
                                    event_data = json.loads(line[6:])
                            # Pick up the actual conversation_id from the start event
                            if event_type == "start" and event_data:
                                conv_id = event_data.get("conversation_id")
                                if conv_id:
                                    returned_conversation_id = conv_id
                            # Accumulate message content
                            if event_type == "message" and event_data:
                                chunk = event_data.get("content", "")
                                full_content += chunk
                                # Forward the message-chunk event (tagged with the model)
                                await event_queue.put(self._format_sse_event("model_message", {
                                    "model_index": idx,
                                    "model_config_id": model_config_id,
                                    "label": model_label,
                                    "conversation_id": returned_conversation_id,
                                    "content": chunk
                                }))
                        except Exception as e:
                            logger.warning(f"Failed to parse streamed event: {e}")
                finally:
                    # Restore the original parameters
                    agent_config.model_parameters = original_params
                elapsed = time.time() - start_time
                # Build the result (mirrors run_compare)
                result = {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "label": model_label,
                    "conversation_id": returned_conversation_id,
                    "parameters_used": model_info["parameters"],
                    "message": full_content,
                    "elapsed_time": elapsed,
                    "error": None
                }
                # Emit the model end event
                await event_queue.put(self._format_sse_event("model_end", {
                    "model_index": idx,
                    "model_config_id": model_config_id,
                    "label": model_label,
                    "conversation_id": returned_conversation_id,
                    "elapsed_time": elapsed,
                    "message_length": len(full_content),
                    "timestamp": time.time()
                }))
                return result
            except TimeoutError:
                logger.warning(f"Model run timed out: {model_label}")
                result = {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "label": model_label,
                    "conversation_id": model_conversation_id,
                    "parameters_used": model_info["parameters"],
                    "elapsed_time": timeout,
                    "error": f"Execution timed out ({timeout}s)"
                }
                await event_queue.put(self._format_sse_event("model_error", {
                    "model_index": idx,
                    "model_config_id": model_config_id,
                    "label": model_label,
                    "conversation_id": model_conversation_id,
                    "error": result["error"],
                    "timestamp": time.time()
                }))
                return result
            except Exception as e:
                logger.error(f"Model run failed: {model_label}, error: {e}")
                result = {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "label": model_label,
                    "conversation_id": model_conversation_id,
                    "parameters_used": model_info["parameters"],
                    "elapsed_time": 0,
                    "error": str(e)
                }
                await event_queue.put(self._format_sse_event("model_error", {
                    "model_index": idx,
                    "model_config_id": model_config_id,
                    "label": model_label,
                    "conversation_id": model_conversation_id,
                    "error": str(e),
                    "timestamp": time.time()
                }))
                return result
        if parallel:
            # Run all models in parallel (mirrors run_compare)
            logger.debug(f"Running {len(models)} models in parallel (streaming)")
            # Create the shared event queue
            event_queue = asyncio.Queue()
            # Launch one task per model
            tasks = [
                asyncio.create_task(run_single_model_stream(idx, model_info, event_queue))
                for idx, model_info in enumerate(models)
            ]
            # Keep draining the queue and forwarding events
            completed_tasks = set()
            while len(completed_tasks) < len(tasks):
                try:
                    # Try to fetch an event from the queue
                    event = await asyncio.wait_for(event_queue.get(), timeout=0.1)
                    yield event
                except TimeoutError:
                    # Check whether any task has finished
                    for task in tasks:
                        if task.done() and task not in completed_tasks:
                            completed_tasks.add(task)
                            try:
                                result = await task
                                if result:
                                    results.append(result)
                            except Exception as e:
                                logger.error(f"Failed to collect task result: {e}")
                    continue
            # Drain any events left in the queue
            while not event_queue.empty():
                try:
                    event = event_queue.get_nowait()
                    yield event
                except asyncio.QueueEmpty:
                    break
        else:
            # Run models one at a time (mirrors run_compare)
            logger.debug(f"Running {len(models)} models serially (streaming)")
            for idx, model_info in enumerate(models):
                # A temporary queue for this single model
                event_queue = asyncio.Queue()
                # Run the model
                result = await run_single_model_stream(idx, model_info, event_queue)
                if result:
                    results.append(result)
                # Forward all of this model's events
                while not event_queue.empty():
                    try:
                        event = event_queue.get_nowait()
                        yield event
                    except asyncio.QueueEmpty:
                        break
        # Statistics (mirrors run_compare)
        successful = [r for r in results if not r.get("error")]
        failed = [r for r in results if r.get("error")]
        fastest = min(successful, key=lambda x: x["elapsed_time"]) if successful else None
        cheapest = min(
            successful,
            key=lambda x: x.get("cost_estimate") or float("inf")
        ) if successful else None
        # Build the result summary (including the full message)
        results_summary = []
        for r in results:
            results_summary.append({
                "model_config_id": str(r["model_config_id"]),
                "model_name": r["model_name"],
                "label": r["label"],
                "conversation_id": r.get("conversation_id"),
                "message": r.get("message"),  # full message included
                "elapsed_time": r.get("elapsed_time", 0),
                "error": r.get("error")
            })
        # Emit the comparison end event (same shape as the run_compare return value)
        yield self._format_sse_event("compare_end", {
            "conversation_id": conversation_id,
            "results": results_summary,  # full results included
            "total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results),
            "successful_count": len(successful),
            "failed_count": len(failed),
            "fastest_model": fastest["label"] if fastest else None,
            "cheapest_model": cheapest["label"] if cheapest else None,
            "timestamp": time.time()
        })
        logger.info(
            "Multi-model comparison streaming run completed",
            extra={
                "successful": len(successful),
                "failed": len(failed),
                "total_time": sum(r.get("elapsed_time", 0) for r in results)
            }
        )


async def draft_run(
    db: Session,
    *,
    agent_config: AgentConfig,
    model_config: ModelConfig,
    message: str,
    workspace_id: uuid.UUID,
    user_id: Optional[str] = None
) -> Dict[str, Any]:
    """Draft-run an Agent (convenience function).

    Args:
        db: Database session
        agent_config: Agent configuration
        model_config: Model configuration
        message: User message
        workspace_id: Workspace ID (required by DraftRunService.run)
        user_id: User ID

    Returns:
        Dict: A dictionary containing the AI reply and metadata
    """
    service = DraftRunService(db)
    # Note: knowledge base selection (kb_ids, similarity_threshold, top_k) is
    # configured on agent_config.knowledge_retrieval rather than passed here;
    # DraftRunService.run does not accept those keyword arguments.
    return await service.run(
        agent_config=agent_config,
        model_config=model_config,
        message=message,
        workspace_id=workspace_id,
        user_id=user_id
    )
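
# A minimal, hypothetical invocation of the convenience function above (the
# session factory and loaded configs are assumptions for illustration):
#
#     from app.db import SessionLocal  # assumed session factory
#
#     async def demo(agent_cfg, model_cfg, workspace_id):
#         db = SessionLocal()
#         try:
#             reply = await draft_run(
#                 db,
#                 agent_config=agent_cfg,
#                 model_config=model_cfg,
#                 message="Hello",
#                 workspace_id=workspace_id,
#             )
#             print(reply["message"], reply["usage"])
#         finally:
#             db.close()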