""" 试运行服务 提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。 """ import asyncio import datetime import json import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional from langchain.agents import create_agent from langchain.tools import tool from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.langchain_agent import LangChainAgent from app.core.config import settings from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.memory.enums import SearchStrategy from app.core.memory.memory_service import MemoryService from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context from app.models import AgentConfig, ModelConfig from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput, Citation, FileType from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services.conversation_service import ConversationService from app.services.langchain_tool_server import Search from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_service import ModelApiKeyService from app.services.multimodal_service import MultimodalService from app.services.tool_service import ToolService logger = get_business_logger() class KnowledgeRetrievalInput(BaseModel): """知识库检索工具输入参数""" query: str = Field(description="需要检索的问题或关键词") class WebSearchInput(BaseModel): """网络搜索工具输入参数""" query: str = Field(description="需要搜索的问题或关键词") class LongTermMemoryInput(BaseModel): """长期记忆工具输入参数""" question: str = Field( description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写") def create_long_term_memory_tool( memory_config: Dict[str, Any], end_user_id: str, storage_type: 
Optional[str] = None, user_rag_memory_id: Optional[str] = None ): """创建记忆工具, Args: memory_config: 记忆配置 end_user_id: 用户ID storage_type: 存储类型(可选) user_rag_memory_id: 用户RAG记忆ID(可选) Returns: 长期记忆工具 """ # search_switch = memory_config.get("search_switch", "2") # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None) logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}") @tool(args_schema=LongTermMemoryInput) def long_term_memory(question: str) -> str: """ 从用户的历史记忆中检索相关信息。用于了解用户的背景、偏好和历史对话内容。 **何时使用此工具:** - 用户明确询问历史信息(如"我之前说过什么"、"上次我们聊了什么") - 用户询问个人信息或偏好(如"我喜欢什么"、"我的习惯是什么") - 需要基于历史上下文提供个性化建议 **何时不使用此工具:** - 简单问候(如"你好"、"谢谢"、"再见") - 纯任务性请求(如"写代码"、"翻译文字"、"分析图片") - 用户已提供完整信息(如提供了文本、图片、文档等内容) - 创作性任务(如"写诗"、"编故事"、"创作谜语") **重要:如果用户的问题可以直接回答,不要调用此工具。只在确实需要历史信息时才使用。** Args: question: 需要检索的问题(保持原问题的核心语义,使用清晰的关键词,第三人称描述的偏好、行为通常指用户本人,比如(我,本人,在下,自己,咱,鄙人,吴,余)通指用户) Returns: 检索到的历史记忆内容 """ logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}") try: with get_db_context() as db: memory_service = MemoryService(db, config_id, end_user_id) search_result = asyncio.run(memory_service.read(question, SearchStrategy.QUICK)) # memory_content = asyncio.run( # MemoryAgentService().read_memory( # end_user_id=end_user_id, # message=question, # history=[], # search_switch="2", # config_id=config_id, # db=db, # storage_type=storage_type, # user_rag_memory_id=user_rag_memory_id # ) # ) # task = celery_app.send_task( # "app.core.memory.agent.read_message", # args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id] # ) # result = task_service.get_task_memory_read_result(task.id) # status = result.get("status") # logger.info(f"读取任务状态:{status}") # if memory_content: # memory_content = memory_content['answer'] # logger.info(f'用户ID:Agent:{end_user_id}') # logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": 
end_user_id}) # # logger.info( # "长期记忆检索成功", # extra={ # "end_user_id": end_user_id, # "content_length": len(str(memory_content)) # } # ) return f"检索到以下历史记忆:\n\n{search_result.content}" except Exception as e: logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__}) return f"记忆检索失败: {str(e)}" return long_term_memory def create_web_search_tool(web_search_config: Dict[str, Any]): """创建网络搜索工具 Args: web_search_config: 网络搜索配置 Returns: 网络搜索工具 """ logger.info("创建网络搜索工具") @tool(args_schema=WebSearchInput) def web_search_tool(query: str) -> str: """从互联网搜索最新信息。当用户的问题需要实时信息、最新新闻或网络资料时,使用此工具进行搜索。 Args: query: 需要搜索的问题或关键词 Returns: 搜索到的相关网络信息 """ try: logger.info(f"执行网络搜索: {query}") # 调用搜索服务 search_result = Search(query) logger.info( "网络搜索成功", extra={ "query": query, "result_length": len(search_result) } ) return f"搜索到以下网络信息:\n\n{search_result}" except Exception as e: logger.error("网络搜索失败", extra={"error": str(e), "error_type": type(e).__name__}) return f"搜索失败: {str(e)}" return web_search_tool def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: kb_config: 知识库配置 kb_ids: 知识库ID列表 user_id: 用户ID citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充) 列表元素类型为 Citation,包含字段: - document_id: 文档唯一标识 - file_name: 文件名 - knowledge_id: 知识库 ID - score: 检索相关性得分 Returns: 检索到的相关知识内容 """ logger.info(f"创建知识库检索工具,用户:{user_id}") @tool(args_schema=KnowledgeRetrievalInput) def knowledge_retrieval_tool(query: str) -> str: """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: query: 需要检索的问题或关键词 Returns: 检索到的相关知识内容 """ try: retrieve_chunks_result = knowledge_retrieval(query, kb_config) if retrieve_chunks_result: retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] context = '\n\n'.join(retrieval_knowledge) logger.info( "知识库检索成功", extra={ "kb_ids": kb_ids, "result_count": len(retrieval_knowledge), "total_length": len(context) } ) # 收集引用信息 if 
citations_collector is not None: seen_doc_ids = {c.get("document_id") for c in citations_collector} for chunk in retrieve_chunks_result: meta = chunk.metadata or {} doc_id = meta.get("document_id") or meta.get("doc_id") if doc_id and doc_id not in seen_doc_ids: seen_doc_ids.add(doc_id) citations_collector.append(Citation( document_id=doc_id, file_name=meta.get("file_name", ""), knowledge_id=str(meta.get("knowledge_id", "")), score=meta.get("score", 0) )) return f"检索到以下相关信息:\n\n{context}" else: logger.warning("知识库检索未找到结果") return "未找到相关信息" except Exception as e: logger.error("知识库检索失败", extra={"error": str(e), "error_type": type(e).__name__}) return f"检索失败: {str(e)}" return knowledge_retrieval_tool class AgentRunService: """Agent运行服务类""" def __init__(self, db: Session): """Agent运行服务 Args: db: 数据库会话 """ self.db = db @staticmethod def prepare_variables( input_vars: dict | None, variables_config: dict ) -> dict: input_vars = input_vars or {} for variable in variables_config: if variable.get("required") and variable.get("name") not in input_vars: raise ValueError(f"The required parameter '{variable.get('name')}' was not provided") return input_vars def load_tools_config(self, tools_config, web_search, tenant_id) -> list: """加载工具配置""" tools = [] if web_search: search_tool = create_web_search_tool({}) tools.append(search_tool) if not tools_config: return tools tool_service = ToolService(self.db) if tools_config and isinstance(tools_config, list): for tool_config in tools_config: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: # 转换为LangChain工具 langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) logger.debug( "已添加网络搜索工具", extra={ "tool_count": len(tools) } ) return tools def load_skill_config( self, skills_config: dict | None, message: str, tenant_id ) -> tuple[list, str]: if not skills_config: return 
[], "" tools = [] skill_prompts = "" skill_enable = skills_config.get("enabled", False) if skill_enable: middleware = AgentMiddleware(skills=skills_config) skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) tools.extend(skill_tools) logger.debug(f"已加载 {len(skill_tools)} 个技能工具") if skill_configs: tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) logger.debug(f"过滤后剩余 {len(tools)} 个工具") skill_prompts = AgentMiddleware.get_active_prompts( activated_skill_ids, skill_configs ) return tools, skill_prompts def load_knowledge_retrieval_config( self, knowledge_retrieval_config: dict | None, user_id ) -> tuple[list, list]: """返回 (tools, citations_collector)""" if not knowledge_retrieval_config: return [], [] citations_collector = [] tools = [] knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")] if kb_ids: kb_tool = create_knowledge_retrieval_tool( knowledge_retrieval_config, kb_ids, user_id, citations_collector=citations_collector ) tools.append(kb_tool) logger.debug( "已添加知识库检索工具", extra={"kb_ids": kb_ids, "tool_count": len(tools)} ) return tools, citations_collector def load_memory_config( self, memory_config: dict | None, user_id, storage_type, user_rag_memory_id ) -> tuple[list, bool]: """加载长期记忆配置""" if not memory_config: return [], False tools = [] if memory_config.get("enabled"): if user_id: # 创建长期记忆工具 memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, user_rag_memory_id) tools.append(memory_tool) logger.debug( "已添加长期记忆工具", extra={ "user_id": user_id, "tool_count": len(tools) } ) return tools, bool(memory_config.get("enabled")) @staticmethod def _validate_file_upload( features_config: Dict[str, Any], files: Optional[List[FileInput]] ) -> None: """校验上传文件是否符合 file_upload 配置""" if not files or not features_config: return fu = features_config.get("file_upload", 
{}) if not (isinstance(fu, dict) and fu.get("enabled")): raise BusinessException("该应用未开启文件上传功能", BizCode.BAD_REQUEST) max_count = fu.get("max_file_count", 5) if len(files) > max_count: raise BusinessException(f"文件数量超过限制(最多 {max_count} 个)", BizCode.BAD_REQUEST) # 校验传输方式 allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"]) for f in files: if f.transfer_method.value not in allowed_methods: raise BusinessException( f"不支持的文件传输方式:{f.transfer_method.value},允许的方式:{', '.join(allowed_methods)}", BizCode.BAD_REQUEST ) # 各类型对应的开关和大小限制配置键 type_cfg = { "image": ("image_enabled", "image_max_size_mb", 20, "图片"), "audio": ("audio_enabled", "audio_max_size_mb", 50, "音频"), "document": ("document_enabled", "document_max_size_mb", 100, "文档"), "video": ("video_enabled", "video_max_size_mb", 500, "视频"), } for f in files: ftype = str(f.type) # 如 "image", "audio", "document", "video" cfg = type_cfg.get(ftype) if cfg is None: continue enabled_key, size_key, default_max_mb, label = cfg # 校验类型开关 if not fu.get(enabled_key): raise BusinessException(f"该应用未开启{label}文件上传", BizCode.BAD_REQUEST) # 校验文件大小(仅当内容已加载时) content = f.get_content() if content is not None: max_mb = fu.get(size_key, default_max_mb) size_mb = len(content) / (1024 * 1024) if size_mb > max_mb: raise BusinessException( f"{label}文件大小超过限制(最大 {max_mb}MB,当前 {size_mb:.1f}MB)", BizCode.BAD_REQUEST ) @staticmethod def _get_opening_statement( features_config: Dict[str, Any], is_new_conversation: bool, variables: Optional[Dict[str, Any]] = None ) -> tuple[Any, Any]: """首轮对话时返回开场白文本(支持变量替换),否则返回 None""" if not is_new_conversation: return None, None opening = features_config.get("opening_statement", {}) if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")): return None, None statement = opening["statement"] suggested_questions = opening["suggested_questions"] # 如果有变量,进行替换(仅支持 {{var_name}} 格式) if variables: for var_name, var_value in variables.items(): placeholder = 
f"{{{{{var_name}}}}}" statement = statement.replace(placeholder, str(var_value)) return statement, suggested_questions @staticmethod def _filter_citations( features_config: Dict[str, Any], citations: List[Citation] ) -> List[Any]: """根据 citation 开关决定是否返回引用来源,并根据 allow_download 附加下载链接""" citation_cfg = features_config.get("citation", {}) if not (isinstance(citation_cfg, dict) and citation_cfg.get("enabled")): return [] allow_download = citation_cfg.get("allow_download", False) result = [] for cit in citations: item = cit.model_dump() if hasattr(cit, "model_dump") else dict(cit) if allow_download and item.get("document_id"): from app.core.config import settings item["download_url"] = f"{settings.FILE_LOCAL_SERVER_URL}/apps/citations/{item['document_id']}/download" result.append(item) return result async def run( self, *, agent_config: AgentConfig, model_config: ModelConfig, message: str, workspace_id: uuid.UUID, conversation_id: Optional[str] = None, user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, web_search: bool = True, memory: bool = True, sub_agent: bool = False, files: Optional[List[FileInput]] = None # 新增:多模态文件 ) -> Dict[str, Any]: """执行试运行(使用 LangChain Agent) Args: agent_config: Agent 配置 model_config: 模型配置 message: 用户消息 workspace_id: 工作空间ID(必须,用于会话隔离) conversation_id: 会话ID(用于多轮对话) user_id: 用户ID variables: 自定义变量参数值 storage_type: 存储类型(可选) user_rag_memory_id: 用户RAG记忆ID(可选) web_search: 是否启用网络搜索(默认True) memory: 是否启用长期记忆(默认True) sub_agent: 是否为子代理调用(默认False) files: 多模态文件列表(可选) Returns: Dict: 包含 AI 回复和元数据的字典 """ start_time = time.time() tools_config: dict | list | None = agent_config.tools skills_config: dict | None = agent_config.skills knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval memory_config: dict | None = agent_config.memory features_config: dict = agent_config.features or {} # 从 features 中读取功能开关(优先级高于参数默认值) web_search_feature = 
features_config.get("web_search", {}) if not isinstance(web_search_feature, dict) or not web_search_feature.get("enabled"): web_search = False # file_upload 校验 self._validate_file_upload(features_config, files) try: # 1. 获取 API Key 配置 api_key_config = await self._get_api_key(model_config.id) logger.debug( "API Key 配置获取成功", extra={ "model_name": api_key_config["model_name"], "has_api_key": bool(api_key_config["api_key"]), "has_api_base": bool(api_key_config.get("api_base")) } ) # 2. 合并模型参数 effective_params = ModelParameterMerger.get_effective_parameters( model_config=model_config, agent_config=agent_config ) if sub_agent: variables = self.prepare_variables(variables, agent_config.variables) else: # FIXME: subagent input valid variables = variables or {} system_prompt = render_prompt_message( agent_config.system_prompt, PromptMessageRole.USER, variables ) # 3. 处理系统提示词(支持变量替换) system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" # 4. 准备工具列表 tools = [] tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 tools.extend(self.load_tools_config(tools_config, web_search, tenant_id)) skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id) tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id) tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False if memory: memory_tools, memory_flag = self.load_memory_config( memory_config, user_id, storage_type, user_rag_memory_id ) tools.extend(memory_tools) # 5. 
处理会话ID(创建或验证),新会话时写入开场白 is_new_conversation = not conversation_id opening, suggested_questions = None, None if not sub_agent: opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables) conversation_id = await self._ensure_conversation( conversation_id=conversation_id, app_id=agent_config.app_id, workspace_id=workspace_id, user_id=user_id, opening_statement=opening, suggested_questions=suggested_questions ) model_info = ModelInfo( model_name=api_key_config["model_name"], provider=api_key_config["provider"], api_key=api_key_config["api_key"], api_base=api_key_config["api_base"], capability=api_key_config["capability"], is_omni=api_key_config["is_omni"], model_type=model_config.type ) # 6. 加载历史消息(包含开场白) history = await self._load_conversation_history( conversation_id=conversation_id, max_history=10, current_provider=api_key_config.get("provider"), current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 processed_files = None has_doc_with_images = False if files: provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) fu_config = features_config.get("file_upload", {}) if hasattr(fu_config, "model_dump"): fu_config = fu_config.model_dump() doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False) processed_files = await multimodal_service.process_files( files, document_image_recognition=doc_img_recognition, workspace_id=workspace_id ) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") capability = api_key_config.get("capability", []) has_doc_with_images = ( doc_img_recognition and "vision" in capability and any(f.type == FileType.DOCUMENT for f in files) ) if has_doc_with_images: system_prompt += ( "\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: ," "请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。" "重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx)," "必须将 src 属性的值原封不动复制到 Markdown 
的括号中,不得增删任何字符。" ) agent = LangChainAgent( model_name=api_key_config["model_name"], api_key=api_key_config["api_key"], provider=api_key_config.get("provider", "openai"), api_base=api_key_config.get("api_base"), is_omni=api_key_config.get("is_omni", False), temperature=effective_params.get("temperature", 0.7), max_tokens=effective_params.get("max_tokens", 2000), system_prompt=system_prompt, tools=tools, deep_thinking=effective_params.get("deep_thinking", False), thinking_budget_tokens=effective_params.get("thinking_budget_tokens"), json_output=effective_params.get("json_output", False), capability=api_key_config.get("capability", []), ) # 为需要运行时上下文的工具注入上下文 for t in tools: if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): t.tool_instance.set_runtime_context( user_id=user_id or "anonymous", conversation_id=str(conversation_id) if conversation_id else None, uploaded_files=processed_files or [] ) # 7. 知识库检索 context = None logger.debug( "准备调用 LangChain Agent", extra={ "model": api_key_config["model_name"], "has_history": bool(history), "has_context": bool(context), "has_files": bool(processed_files) } ) memory_config_ = agent_config.memory # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None) # 8. 调用 Agent(支持多模态) result = await agent.chat( message=message, history=history, context=context, files=processed_files # 传递处理后的文件 ) elapsed_time = time.time() - start_time ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) # 9. 生成 TTS audio_url(在保存消息前生成,以便一并存入 meta_data) audio_url = await self._generate_tts( features_config, result["content"], api_key_config, tenant_id=tenant_id, workspace_id=workspace_id ) if not sub_agent else None # 过滤 citations(只调用一次) filtered_citations = self._filter_citations(features_config, citations_collector) # 10. 
保存会话消息 if not sub_agent: await self._save_conversation_message( conversation_id=conversation_id, user_message=message, assistant_message=result["content"], app_id=agent_config.app_id, user_id=user_id, meta_data={ "usage": result.get("usage", { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 }), "reasoning_content": result.get("reasoning_content") }, files=files, processed_files=processed_files, audio_url=audio_url, citations=filtered_citations, provider=api_key_config.get("provider"), is_omni=api_key_config.get("is_omni", False) ) response = { "message": result["content"], "reasoning_content": result.get("reasoning_content"), "conversation_id": conversation_id, "usage": result.get("usage", { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 }), "elapsed_time": elapsed_time, "suggested_questions": await self._generate_suggested_questions( features_config, result["content"], api_key_config, effective_params ) if not sub_agent else [], "citations": filtered_citations, "audio_url": audio_url, "audio_status": "pending" if audio_url else None } logger.info( "试运行完成", extra={ "model": model_config.name, "elapsed_time": elapsed_time, "message_length": len(result["content"]), "total_tokens": result.get("usage", {}).get("total_tokens", 0) } ) return response except Exception as e: logger.error("LangChain Agent 调用失败", extra={"error": str(e), "error_type": type(e).__name__}) raise BusinessException(f"Agent 调用失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e) async def run_stream( self, *, agent_config: AgentConfig, model_config: ModelConfig, message: str, workspace_id: uuid.UUID, conversation_id: Optional[str] = None, user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, web_search: bool = True, # 布尔类型默认值 memory: bool = True, # 布尔类型默认值 sub_agent: bool = False, # 是否是作为子Agent运行 files: Optional[List[FileInput]] = None # 新增:多模态文件 ) -> AsyncGenerator[str, None]: 
"""执行试运行(流式返回,使用 LangChain Agent) Args: agent_config: Agent 配置 model_config: 模型配置 message: 用户消息 workspace_id: 工作空间ID(必须,用于会话隔离) conversation_id: 会话ID(用于多轮对话) user_id: 用户ID variables: 自定义变量参数值 Yields: str: SSE 格式的事件数据 """ tools_config: dict | list | None = agent_config.tools skills_config: dict | None = agent_config.skills knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval memory_config: dict | None = agent_config.memory features_config: dict = agent_config.features or {} # 从 features 中读取功能开关 web_search_feature = features_config.get("web_search", {}) if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")): web_search = False # file_upload 校验 self._validate_file_upload(features_config, files) start_time = time.time() try: # 1. 获取 API Key 配置 api_key_config = await self._get_api_key(model_config.id) if not sub_agent: variables = self.prepare_variables(variables, agent_config.variables) else: # FIXME: subagent input valid variables = variables or {} # 2. 合并模型参数 effective_params = ModelParameterMerger.get_effective_parameters( model_config=model_config, agent_config=agent_config ) items_params = variables system_prompt = render_prompt_message( agent_config.system_prompt, # 修正拼写错误 PromptMessageRole.USER, items_params ) # 3. 处理系统提示词(支持变量替换) system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" # 4. 
准备工具列表 tools = [] tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 tools.extend(self.load_tools_config(tools_config, web_search, tenant_id)) skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id) tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id) tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False if memory: memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type, user_rag_memory_id) tools.extend(memory_tools) # 5. 处理会话ID(创建或验证),新会话时写入开场白 is_new_conversation = not conversation_id opening, suggested_questions = None, None if not sub_agent: opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables) conversation_id = await self._ensure_conversation( conversation_id=conversation_id, app_id=agent_config.app_id, workspace_id=workspace_id, user_id=user_id, sub_agent=sub_agent, opening_statement=opening, suggested_questions=suggested_questions ) model_info = ModelInfo( model_name=api_key_config["model_name"], provider=api_key_config["provider"], api_key=api_key_config["api_key"], api_base=api_key_config["api_base"], capability=api_key_config["capability"], is_omni=api_key_config["is_omni"], model_type=model_config.type ) # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, max_history=memory_config.get("max_history", 10), current_provider=api_key_config.get("provider"), current_is_omni=api_key_config.get("is_omni", False) ) # 6. 
处理多模态文件 processed_files = None has_doc_with_images = False if files: provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) fu_config = features_config.get("file_upload", {}) if hasattr(fu_config, "model_dump"): fu_config = fu_config.model_dump() doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False) processed_files = await multimodal_service.process_files( files, document_image_recognition=doc_img_recognition, workspace_id=workspace_id ) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") capability = api_key_config.get("capability", []) has_doc_with_images = ( doc_img_recognition and "vision" in capability and any(f.type == FileType.DOCUMENT for f in files) ) if has_doc_with_images: system_prompt += ( "\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: ," "请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。" "重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx)," "必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。" ) # 创建 LangChain Agent agent = LangChainAgent( model_name=api_key_config["model_name"], api_key=api_key_config["api_key"], provider=api_key_config.get("provider", "openai"), api_base=api_key_config.get("api_base"), is_omni=api_key_config.get("is_omni", False), temperature=effective_params.get("temperature", 0.7), max_tokens=effective_params.get("max_tokens", 2000), system_prompt=system_prompt, tools=tools, streaming=True, deep_thinking=effective_params.get("deep_thinking", False), thinking_budget_tokens=effective_params.get("thinking_budget_tokens"), json_output=effective_params.get("json_output", False), capability=api_key_config.get("capability", []), ) # 为需要运行时上下文的工具注入上下文 for t in tools: if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): t.tool_instance.set_runtime_context( user_id=user_id or "anonymous", conversation_id=str(conversation_id) if conversation_id else None, uploaded_files=processed_files or [] ) 
# 7. 知识库检索 context = None # 8. 发送开始事件 yield self._format_sse_event("start", { "conversation_id": conversation_id, "timestamp": time.time() }) memory_config_ = agent_config.memory # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None) # 9. 流式调用 Agent(支持多模态),同时并行启动 TTS full_content = "" full_reasoning = "" total_tokens = 0 # 启动流式 TTS(文本边输出边合成) text_queue: asyncio.Queue = asyncio.Queue() stream_audio_url, tts_task = await self._generate_tts_streaming( features_config, api_key_config, text_queue=text_queue, tenant_id=tenant_id, workspace_id=workspace_id ) if not sub_agent else (None, None) async for chunk in agent.chat_stream( message=message, history=history, context=context, files=processed_files ): if isinstance(chunk, int): total_tokens = chunk elif isinstance(chunk, dict) and chunk.get("type") == "reasoning": full_reasoning += chunk["content"] yield self._format_sse_event("reasoning", {"content": chunk["content"]}) else: full_content += chunk yield self._format_sse_event("message", {"content": chunk}) if tts_task is not None: await text_queue.put(chunk) # 文本结束,通知 TTS if tts_task is not None: await text_queue.put(None) elapsed_time = time.time() - start_time ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) if sub_agent: yield self._format_sse_event("sub_usage", {"total_tokens": total_tokens}) # 过滤 citations(只调用一次) filtered_citations = self._filter_citations(features_config, citations_collector) # 11. 
保存会话消息 if not sub_agent: await self._save_conversation_message( conversation_id=conversation_id, user_message=message, assistant_message=full_content, app_id=agent_config.app_id, user_id=user_id, meta_data={ "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}, "reasoning_content": full_reasoning or None }, files=files, processed_files=processed_files, audio_url=stream_audio_url, citations=filtered_citations, provider=api_key_config.get("provider"), is_omni=api_key_config.get("is_omni", False) ) # 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status) end_data: Dict[str, Any] = { "conversation_id": conversation_id, "elapsed_time": elapsed_time, "message_length": len(full_content) } if not sub_agent: end_data["suggested_questions"] = await self._generate_suggested_questions( features_config, full_content, api_key_config, effective_params ) end_data["audio_url"] = stream_audio_url # 检查TTS是否已完成(非阻塞,不取消任务) audio_status = "pending" if tts_task is not None and tts_task.done(): # 任务已完成,检查是否有异常 try: tts_task.result() audio_status = "completed" except Exception as e: logger.warning(f"TTS任务异常: {e}") audio_status = "failed" end_data["audio_status"] = audio_status if stream_audio_url else None end_data["citations"] = filtered_citations yield self._format_sse_event("end", end_data) logger.info( "流式试运行完成", extra={ "model": model_config.name, "elapsed_time": elapsed_time, "message_length": len(full_content) } ) except Exception as e: logger.error("流式 Agent 调用失败", extra={"error": str(e)}, exc_info=True) # 发送错误事件 yield self._format_sse_event("error", { "error": str(e), "timestamp": time.time() }) def _format_sse_event(self, event: str, data: Dict[str, Any]) -> str: """格式化 SSE 事件 Args: event: 事件类型 data: 事件数据 Returns: str: SSE 格式的字符串 """ return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict: """获取模型的 API Key Args: model_config_id: 模型配置ID Returns: Dict: 包含 
    async def _ensure_conversation(
        self,
        conversation_id: Optional[str],
        app_id: uuid.UUID,
        workspace_id: uuid.UUID,
        user_id: Optional[str],
        sub_agent: bool = False,
        opening_statement: Optional[str] = None,
        suggested_questions: Optional[List[str]] = None
    ) -> str:
        """Ensure the conversation exists (create a new one or validate the given one).

        Args:
            conversation_id: Conversation ID (optional; a new draft conversation
                is created when missing).
            app_id: Application ID.
            workspace_id: Workspace ID (required).
            user_id: User ID.
            sub_agent: Whether this is an internal sub-agent call (skips the
                workspace-ownership check, already done by the caller).
            opening_statement: Opening statement; written as the first assistant
                message when a new conversation is created.
            suggested_questions: Preset suggested questions stored alongside the
                opening statement.

        Returns:
            str: The conversation ID.

        Raises:
            BusinessException: When the given conversation does not exist, or it
                does not belong to the current workspace.
        """
        from app.models import Conversation as ConversationModel
        from app.services.conversation_service import ConversationService

        conversation_service = ConversationService(self.db)

        # No conversation ID supplied: create a fresh draft conversation.
        if not conversation_id:
            logger.info(
                "创建新的草稿会话",
                extra={"workspace_id": str(workspace_id)}
            )

            # Snapshot the current agent/model config so this draft run stays
            # reproducible even if the config is edited later.
            config_snapshot = await self._get_config_snapshot(app_id)

            # Create the new conversation row.
            new_conv_id = str(uuid.uuid4())
            new_conversation = ConversationModel(
                id=uuid.UUID(new_conv_id),
                app_id=app_id,
                workspace_id=workspace_id,
                user_id=user_id,
                is_draft=True,
                title="草稿会话",
                config_snapshot=config_snapshot
            )
            self.db.add(new_conversation)
            self.db.commit()
            self.db.refresh(new_conversation)

            # If an opening statement is configured, persist it as the first
            # assistant message of the new conversation.
            if opening_statement:
                conversation_service.add_message(
                    conversation_id=uuid.UUID(new_conv_id),
                    role="assistant",
                    content=opening_statement,
                    meta_data={"suggested_questions": suggested_questions}
                )
                logger.debug(f"已保存开场白到会话 {new_conv_id}")

            logger.info(
                "创建草稿会话成功",
                extra={
                    "conversation_id": new_conv_id,
                    "workspace_id": str(workspace_id)
                }
            )
            return new_conv_id

        # A conversation ID was supplied: validate existence and workspace ownership.
        try:
            conv_uuid = uuid.UUID(conversation_id)
            conversation = conversation_service.get_conversation(conv_uuid)

            # Verify the conversation belongs to the current workspace (or to the
            # source workspace of a shared app). Internal sub_agent calls skip the
            # check — it was already performed upstream.
            if not sub_agent and conversation.workspace_id != workspace_id:
                # Check whether this is a shared-app conversation (a sharee
                # workspace accessing the source app).
                from app.models import AppShare
                from sqlalchemy import select as sa_select
                share = self.db.scalars(
                    sa_select(AppShare).where(
                        AppShare.source_app_id == app_id,
                        AppShare.target_workspace_id == workspace_id
                    )
                ).first()
                # Case 2: on sub_agent internal calls workspace_id is the source
                # app's workspace while the conversation was created by a sharee;
                # allow it as long as the conversation belongs to the same app.
                same_app = (conversation.app_id == app_id)
                if not share and not same_app:
                    logger.warning(
                        "会话不属于当前工作空间",
                        extra={
                            "conversation_id": conversation_id,
                            "conversation_workspace_id": str(conversation.workspace_id),
                            "current_workspace_id": str(workspace_id)
                        }
                    )
                    raise BusinessException(
                        "会话不属于当前工作空间",
                        BizCode.PERMISSION_DENIED
                    )

            logger.debug(
                "使用现有会话",
                extra={
                    "conversation_id": conversation_id,
                    "workspace_id": str(workspace_id)
                }
            )
            return conversation_id
        except BusinessException:
            # Business errors (e.g. permission denied) propagate unchanged.
            raise
        except Exception as e:
            # Invalid UUID or missing row both surface as "not found".
            logger.error(
                "会话不存在或无效",
                extra={"conversation_id": conversation_id, "error": str(e)}
            )
            raise BusinessException(
                f"会话不存在: {conversation_id}",
                BizCode.NOT_FOUND,
                cause=e
            )
    async def _load_conversation_history(
        self,
        conversation_id: str,
        max_history: int = 10,
        current_provider: Optional[str] = None,
        current_is_omni: Optional[bool] = None
    ) -> List[Dict[str, str]]:
        """Load conversation history, adapting multimodal files to the current model.

        Args:
            conversation_id: Conversation ID.
            max_history: Maximum number of history messages to load.
            current_provider: Provider of the current model (drives multimodal
                file conversion downstream).
            current_is_omni: Whether the current model is omni-modal.

        Returns:
            List[Dict]: The history messages (empty list on any failure).
        """
        try:
            conversation_service = ConversationService(self.db)
            # Fetch history; the service handles multimodal conversion based on
            # the current provider / omni capability.
            history = await conversation_service.get_conversation_history(
                conversation_id=uuid.UUID(conversation_id),
                max_history=max_history,
                current_provider=current_provider,
                current_is_omni=current_is_omni
            )
            logger.debug(
                "加载会话历史",
                extra={
                    "conversation_id": conversation_id,
                    "max_history": max_history,
                    "loaded_count": len(history)
                }
            )
            return history
        except Exception as e:
            # A brand-new conversation legitimately has no history; treat any
            # failure as "no history" instead of breaking the run.
            logger.debug("加载会话历史失败(可能是新会话)", extra={"error": str(e)})
            return []

    async def _save_conversation_message(
        self,
        conversation_id: str,
        user_message: str,
        assistant_message: str,
        meta_data: dict,
        app_id: Optional[uuid.UUID] = None,
        user_id: Optional[str] = None,
        files: Optional[List[FileInput]] = None,
        processed_files: Optional[List[Dict[str, Any]]] = None,
        audio_url: Optional[str] = None,
        citations: Optional[List[Any]] = None,
        provider: Optional[str] = None,
        is_omni: Optional[bool] = None
    ) -> None:
        """Persist one user/assistant message pair (conversation already exists
        via _ensure_conversation).

        Args:
            conversation_id: Conversation ID.
            user_message: The user message text.
            assistant_message: The AI reply text.
            meta_data: Token-usage metadata for the assistant message.
            app_id: Application ID (unused, kept for compatibility).
            user_id: User ID (unused, kept for compatibility).
            files: Raw file inputs attached to the user message.
            processed_files: Files after multimodal processing.
            audio_url: TTS audio URL for the assistant reply.
            citations: Citation sources for the assistant reply.
            provider: Model provider (stored with history_files).
            is_omni: Whether the model is omni-modal.
        """
        try:
            from app.services.conversation_service import ConversationService
            conversation_service = ConversationService(self.db)
            conv_uuid = uuid.UUID(conversation_id)

            # Metadata attached to the user message.
            human_meta = {
                "files": [],
                "history_files": {}
            }
            if files:
                from app.models.file_metadata_model import FileMetadata
                # Local uploads missing name/size: backfill from FileMetadata in
                # one batched query instead of per-file lookups.
                local_ids = [f.upload_file_id for f in files
                             if f.transfer_method.value == "local_file"
                             and f.upload_file_id
                             and (not f.name or not f.size)]
                meta_map = {}
                if local_ids:
                    rows = self.db.query(FileMetadata).filter(
                        FileMetadata.id.in_(local_ids),
                        FileMetadata.status == "completed"
                    ).all()
                    meta_map = {str(r.id): r for r in rows}
                for f in files:
                    name, size = f.name, f.size
                    if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size):
                        meta = meta_map.get(str(f.upload_file_id))
                        if meta:
                            name = name or meta.file_name
                            size = size or meta.file_size
                    human_meta["files"].append({
                        "type": f.type,
                        "url": f.url,
                        "file_type": f.file_type,
                        "name": name,
                        "size": size
                    })

            # Store history_files together with provider / is_omni so future
            # loads can re-convert them for a different model.
            if processed_files:
                human_meta["history_files"] = {
                    "content": processed_files,
                    "provider": provider,
                    "is_omni": is_omni
                }

            # Save the user message.
            conversation_service.add_message(
                conversation_id=conv_uuid,
                role="user",
                content=user_message,
                meta_data=human_meta
            )

            # Save the assistant message (with audio_url and citations merged in).
            if audio_url:
                meta_data["audio_url"] = audio_url
            if citations:
                meta_data["citations"] = citations
            conversation_service.add_message(
                conversation_id=conv_uuid,
                role="assistant",
                content=assistant_message,
                meta_data=meta_data
            )

            logger.debug(
                "保存会话消息",
                extra={
                    "conversation_id": conversation_id,
                    "user_message_length": len(user_message),
                    "assistant_message_length": len(assistant_message)
                }
            )
        except Exception as e:
            # Persisting history is best-effort; never fail the run over it.
            logger.warning("保存会话消息失败", extra={"error": str(e)})

    async def _get_config_snapshot(self, app_id: uuid.UUID) -> Dict[str, Any]:
        """Build a serializable snapshot of the current agent + model config.

        Args:
            app_id: Application ID.

        Returns:
            Dict: The config snapshot, or {} when no AgentConfig exists (normal
            for multi-agent apps) or on any error.
        """
        try:
            from app.models import AgentConfig, ModelConfig

            # Fetch the agent config for this app.
            stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
            agent_cfg = self.db.scalars(stmt).first()
            if not agent_cfg:
                return {}

            # Fetch the default model config, if one is set.
            model_config = None
            if agent_cfg.default_model_config_id:
                model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)

            # Build the snapshot, making sure every value is JSON-serializable.
            def safe_serialize(value):
                """Best-effort conversion of a value into a serializable form."""
                if value is None:
                    return None
                if isinstance(value, (str, int, float, bool)):
                    return value
                if isinstance(value, (dict, list)):
                    return value
                # Pydantic models or other objects: try dict conversion.
                if hasattr(value, 'dict'):
                    return value.dict()
                if hasattr(value, '__dict__'):
                    return value.__dict__
                return str(value)

            snapshot = {
                "agent_config": {
                    "system_prompt": agent_cfg.system_prompt,
                    "model_parameters": safe_serialize(agent_cfg.model_parameters),
                    "knowledge_retrieval": safe_serialize(agent_cfg.knowledge_retrieval),
                    "memory": safe_serialize(agent_cfg.memory),
                    "variables": safe_serialize(agent_cfg.variables),
                    "tools": safe_serialize(agent_cfg.tools)
                },
                "model_config": {
                    "model_name": model_config.name if model_config else None,
                    "provider": model_config.provider if model_config else None,
                    "type": model_config.type if model_config else None
                } if model_config else None,
                "snapshot_time": datetime.datetime.now().isoformat()
            }
            return snapshot
        except Exception as e:
            # Multi-agent apps have no direct AgentConfig — this is expected.
            logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
            return {}
agent_cfg.system_prompt, "model_parameters": safe_serialize(agent_cfg.model_parameters), "knowledge_retrieval": safe_serialize(agent_cfg.knowledge_retrieval), "memory": safe_serialize(agent_cfg.memory), "variables": safe_serialize(agent_cfg.variables), "tools": safe_serialize(agent_cfg.tools) }, "model_config": { "model_name": model_config.name if model_config else None, "provider": model_config.provider if model_config else None, "type": model_config.type if model_config else None } if model_config else None, "snapshot_time": datetime.datetime.now().isoformat() } return snapshot except Exception as e: # 对于多 Agent 应用,没有直接的 AgentConfig 是正常的 logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)}) return {} async def _generate_suggested_questions( self, features_config: Dict[str, Any], assistant_message: str, api_key_config: Dict[str, Any], effective_params: Dict[str, Any] ) -> List[str]: """根据 suggested_questions_after_answer 配置生成下一步建议问题""" sq_config = features_config.get("suggested_questions_after_answer", {}) if not isinstance(sq_config, dict) or not sq_config.get("enabled"): return [] try: from langchain_openai import ChatOpenAI from langchain_core.messages import HumanMessage, SystemMessage llm = ChatOpenAI( model=api_key_config["model_name"], api_key=api_key_config["api_key"], base_url=api_key_config.get("api_base"), temperature=0.5, max_tokens=200, ) prompt = ( f"根据以下AI回复,生成3个用户可能继续追问的简短问题,每行一个,不加序号:\n\n{assistant_message}" ) resp = await llm.ainvoke([HumanMessage(content=prompt)]) lines = [l.strip() for l in resp.content.strip().split("\n") if l.strip()] return lines[:3] except Exception as e: logger.warning(f"生成建议问题失败: {e}") return [] async def _generate_tts( self, features_config: Dict[str, Any], text: str, api_key_config: Dict[str, Any], tenant_id: Optional[uuid.UUID] = None, workspace_id: Optional[uuid.UUID] = None, ) -> Optional[str]: """先注册文件元数据并返回 audio_url,再后台流式写入音频内容""" tts_config = features_config.get("text_to_speech", {}) if 
    async def _generate_tts(
        self,
        features_config: Dict[str, Any],
        text: str,
        api_key_config: Dict[str, Any],
        tenant_id: Optional[uuid.UUID] = None,
        workspace_id: Optional[uuid.UUID] = None,
    ) -> Optional[str]:
        """Register file metadata and return the audio URL immediately, then
        stream the synthesized audio into storage in the background.

        Args:
            features_config: App features config (reads ``text_to_speech``).
            text: The text to synthesize.
            api_key_config: Resolved model credentials (provider/api_key/api_base).
            tenant_id: Tenant ID for storage pathing.
            workspace_id: Workspace ID for storage pathing.

        Returns:
            Optional[str]: The (eventually valid) audio URL, or None when TTS is
            disabled or the text is empty.
        """
        tts_config = features_config.get("text_to_speech", {})
        if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
            return None
        if not text or not text.strip():
            return None
        from app.models.file_metadata_model import FileMetadata
        from app.services.file_storage_service import FileStorageService, generate_file_key
        provider = api_key_config.get("provider", "openai")
        api_key = api_key_config.get("api_key")
        api_base = api_key_config.get("api_base")
        voice = tts_config.get("voice")
        file_ext, content_type = ".mp3", "audio/mpeg"
        file_id = uuid.uuid4()
        file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)

        # Write the metadata row in "pending" state first so the URL can be
        # returned to the caller before synthesis finishes.
        db_file = FileMetadata(
            id=file_id,
            tenant_id=tenant_id,
            workspace_id=workspace_id,
            file_key=file_key,
            file_name=f"tts_{file_id}{file_ext}",
            file_ext=file_ext,
            file_size=0,
            content_type=content_type,
            status="pending",
        )
        self.db.add(db_file)
        self.db.commit()
        server_url = settings.FILE_LOCAL_SERVER_URL
        audio_url = f"{server_url}/storage/permanent/{file_id}"

        # Background task: stream the synthesized audio into storage, then flip
        # the metadata status to completed/failed.
        async def _stream_to_storage():
            try:
                storage_service = FileStorageService()
                if provider == "dashscope":
                    stream = self._tts_dashscope_stream(
                        api_key=api_key,
                        text=text,
                        voice=voice or "longxiaochun",
                        tts_config=tts_config,
                    )
                else:
                    stream = self._tts_openai_stream(
                        api_key=api_key,
                        api_base=api_base,
                        text=text,
                        voice=voice or "alloy",
                    )
                total_size = await storage_service.upload_stream(
                    tenant_id=tenant_id,
                    workspace_id=workspace_id,
                    file_id=file_id,
                    file_ext=file_ext,
                    stream=stream,
                    content_type=content_type,
                )
                # Update metadata status using a fresh session — the request
                # session may be gone by the time this task runs.
                with get_db_context() as bg_db:
                    record = bg_db.get(FileMetadata, file_id)
                    if record:
                        record.status = "completed"
                        record.file_size = total_size
                        bg_db.commit()
                logger.debug(f"TTS 流式写入完成,provider={provider}, file_key={file_key}")
            except Exception as e:
                logger.warning(f"TTS 流式写入失败: {e}")
                with get_db_context() as bg_db:
                    record = bg_db.get(FileMetadata, file_id)
                    if record:
                        record.status = "failed"
                        bg_db.commit()

        # NOTE(review): the task handle is not retained — assumes the event loop
        # outlives this request long enough for synthesis to finish; confirm.
        asyncio.create_task(_stream_to_storage())
        return audio_url
    async def _generate_tts_streaming(
        self,
        features_config: Dict[str, Any],
        api_key_config: Dict[str, Any],
        text_queue: asyncio.Queue,
        tenant_id: Optional[uuid.UUID] = None,
        workspace_id: Optional[uuid.UUID] = None,
    ) -> tuple[Optional[str], Optional[asyncio.Task]]:
        """Synthesize audio in parallel while text is still streaming in.

        Returns ``(audio_url, task)``: audio_url is usable immediately (file in
        "pending" state); the task completes once the file content is ready.
        The caller puts text chunks into ``text_queue`` and puts ``None`` to
        signal end-of-text. The frontend can poll
        ``GET /storage/files/{file_id}/status`` to check readiness.
        """
        tts_config = features_config.get("text_to_speech", {})
        if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
            return None, None
        from app.models.file_metadata_model import FileMetadata
        from app.services.file_storage_service import FileStorageService, generate_file_key
        provider = api_key_config.get("provider", "openai")
        api_key = api_key_config.get("api_key")
        api_base = api_key_config.get("api_base")
        voice = tts_config.get("voice")
        file_ext, content_type = ".mp3", "audio/mpeg"
        file_id = uuid.uuid4()
        file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)

        # Register a pending metadata row so the URL is valid right away.
        db_file = FileMetadata(
            id=file_id,
            tenant_id=tenant_id,
            workspace_id=workspace_id,
            file_key=file_key,
            file_name=f"tts_{file_id}{file_ext}",
            file_ext=file_ext,
            file_size=0,
            content_type=content_type,
            status="pending",
        )
        self.db.add(db_file)
        self.db.commit()
        server_url = settings.FILE_LOCAL_SERVER_URL
        audio_url = f"{server_url}/storage/permanent/{file_id}"

        async def _run():
            """Consume the text queue, synthesize, upload, and update status."""
            try:
                storage_service = FileStorageService()
                if provider == "dashscope":
                    # DashScope supports true incremental text input.
                    audio_stream = self._tts_dashscope_stream_from_queue(
                        api_key=api_key,
                        voice=voice or "longxiaochun",
                        tts_config=tts_config,
                        text_queue=text_queue,
                    )
                else:
                    # OpenAI collects the full text first (no incremental input).
                    audio_stream = self._tts_openai_stream_from_queue(
                        api_key=api_key,
                        api_base=api_base,
                        voice=voice or "alloy",
                        text_queue=text_queue,
                    )
                total_size = await storage_service.upload_stream(
                    tenant_id=tenant_id,
                    workspace_id=workspace_id,
                    file_id=file_id,
                    file_ext=file_ext,
                    stream=audio_stream,
                    content_type=content_type,
                )
                # Fresh session: the request-scoped session may already be closed.
                with get_db_context() as bg_db:
                    record = bg_db.get(FileMetadata, file_id)
                    if record:
                        record.status = "completed"
                        record.file_size = total_size
                        bg_db.commit()
                logger.debug(f"TTS 流式合成完成,provider={provider}, file_key={file_key}")
            except Exception as e:
                logger.warning(f"TTS 流式合成失败: {e}")
                with get_db_context() as bg_db:
                    record = bg_db.get(FileMetadata, file_id)
                    if record:
                        record.status = "failed"
                        bg_db.commit()

        task = asyncio.create_task(_run())
        return audio_url, task

    @staticmethod
    async def _tts_openai_stream_from_queue(
        api_key: str,
        api_base: Optional[str],
        voice: str,
        text_queue: asyncio.Queue,
    ):
        """OpenAI TTS: collect all text first, then synthesize as a byte stream
        (OpenAI does not support incremental text input)."""
        from openai import AsyncOpenAI
        # Drain the full text (the producer streams in parallel, so the wait is
        # short by the time this runs).
        parts = []
        while True:
            chunk = await text_queue.get()
            if chunk is None:
                break
            parts.append(chunk)
        full_text = "".join(parts)
        if not full_text.strip():
            return
        client = AsyncOpenAI(api_key=api_key, base_url=api_base)
        # Input is truncated to 4096 chars — the OpenAI TTS input limit.
        async with client.audio.speech.with_streaming_response.create(
            model="tts-1",
            voice=voice,
            input=full_text[:4096],
        ) as response:
            async for chunk in response.iter_bytes(chunk_size=4096):
                yield chunk
    @staticmethod
    async def _tts_dashscope_stream_from_queue(
        api_key: str,
        voice: str,
        tts_config: Dict[str, Any],
        text_queue: asyncio.Queue,
    ):
        """DashScope TTS with streaming text input — true parallel synthesis."""
        import dashscope
        from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback
        model = tts_config.get("model") or "cosyvoice-v2"
        # v2 models require "_v2"-suffixed voices and vice versa; normalize.
        is_v2 = model.endswith("-v2")
        if is_v2 and not voice.endswith("_v2"):
            voice = voice + "_v2"
        elif not is_v2 and voice.endswith("_v2"):
            voice = voice[:-3]
        audio_queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_event_loop()

        # SDK callbacks fire on a worker thread; hand results back to the event
        # loop thread-safely. None = normal end, Exception = error sentinel.
        class _Callback(ResultCallback):
            def on_data(self, data: bytes):
                if data:
                    loop.call_soon_threadsafe(audio_queue.put_nowait, data)

            def on_complete(self):
                loop.call_soon_threadsafe(audio_queue.put_nowait, None)

            def on_error(self, message):
                loop.call_soon_threadsafe(audio_queue.put_nowait, RuntimeError(str(message)))

            def on_open(self):
                pass

            def on_close(self):
                pass

        dashscope.api_key = api_key
        synthesizer = SpeechSynthesizer(
            model=model,
            voice=voice,
            format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
            callback=_Callback(),
        )

        async def _feed_text():
            """Pull text from text_queue, split into sentences, feed the synthesizer."""
            import re
            buf = ""
            # CJK and ASCII sentence terminators plus newline.
            sentence_end = re.compile(r'[\u3002\uff01\uff1f.!?\n]')
            while True:
                chunk = await text_queue.get()
                if chunk is None:
                    # Flush the trailing partial sentence, then finish.
                    if buf.strip():
                        await asyncio.to_thread(synthesizer.streaming_call, buf)
                    await asyncio.to_thread(synthesizer.streaming_complete)
                    break
                buf += chunk
                # Feed complete sentences as soon as they appear.
                while sentence_end.search(buf):
                    m = sentence_end.search(buf)
                    sentence = buf[:m.end()]
                    buf = buf[m.end():]
                    await asyncio.to_thread(synthesizer.streaming_call, sentence)

        asyncio.create_task(_feed_text())
        # Relay audio chunks to the consumer until the end sentinel arrives.
        while True:
            item = await audio_queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise item
            yield item

    @staticmethod
    async def _tts_openai_stream(
        api_key: str,
        api_base: Optional[str],
        text: str,
        voice: str,
    ):
        """OpenAI-compatible streaming TTS; yields audio byte chunks."""
        from openai import AsyncOpenAI
        client = AsyncOpenAI(api_key=api_key, base_url=api_base)
        # Input truncated to the 4096-char OpenAI TTS limit.
        async with client.audio.speech.with_streaming_response.create(
            model="tts-1",
            voice=voice,
            input=text[:4096],
        ) as response:
            async for chunk in response.iter_bytes(chunk_size=4096):
                yield chunk

    @staticmethod
    async def _tts_dashscope_stream(
        api_key: str,
        text: str,
        voice: str,
        tts_config: Dict[str, Any],
    ):
        """DashScope streaming TTS for a fixed text; yields audio byte chunks."""
        import dashscope
        from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback
        model = tts_config.get("model") or "cosyvoice-v2"
        # Normalize voice suffix to match the model generation (see queue variant).
        is_v2 = model.endswith("-v2")
        if is_v2 and not voice.endswith("_v2"):
            voice = voice + "_v2"
        elif not is_v2 and voice.endswith("_v2"):
            voice = voice[:-3]
        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_event_loop()

        # Thread-safe bridge from SDK callback thread to the event loop.
        class _Callback(ResultCallback):
            def on_data(self, data: bytes):
                if data:
                    loop.call_soon_threadsafe(queue.put_nowait, data)

            def on_complete(self):
                loop.call_soon_threadsafe(queue.put_nowait, None)

            def on_error(self, message):
                loop.call_soon_threadsafe(queue.put_nowait, RuntimeError(str(message)))

            def on_open(self):
                pass

            def on_close(self):
                pass

        def _sync_stream():
            # Runs in a worker thread (blocking SDK calls).
            dashscope.api_key = api_key
            synthesizer = SpeechSynthesizer(
                model=model,
                voice=voice,
                format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
                callback=_Callback(),
            )
            synthesizer.streaming_call(text[:4096])
            synthesizer.streaming_complete()

        asyncio.create_task(asyncio.to_thread(_sync_stream))
        while True:
            item = await queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise item
            yield item
loop.call_soon_threadsafe(queue.put_nowait, None) def on_error(self, message): loop.call_soon_threadsafe(queue.put_nowait, RuntimeError(str(message))) def on_open(self): pass def on_close(self): pass def _sync_stream(): dashscope.api_key = api_key synthesizer = SpeechSynthesizer( model=model, voice=voice, format=AudioFormat.MP3_22050HZ_MONO_256KBPS, callback=_Callback(), ) synthesizer.streaming_call(text[:4096]) synthesizer.streaming_complete() asyncio.create_task(asyncio.to_thread(_sync_stream)) while True: item = await queue.get() if item is None: break if isinstance(item, Exception): raise item yield item def _replace_variables( self, text: str, values: Dict[str, Any], definitions: List[Dict[str, Any]] ) -> str: """替换文本中的变量 Args: text: 原始文本 values: 变量值 definitions: 变量定义 Returns: str: 替换后的文本 """ result = text # 创建变量定义映射 var_defs = {var["name"]: var for var in definitions} for var_name, var_value in values.items(): # 检查变量是否在定义中 if var_name not in var_defs: logger.warning(f"未定义的变量: {var_name}") continue # 替换变量(支持多种格式) placeholders = [ f"{{{{{var_name}}}}}", # {{var_name}} f"{{{var_name}}}", # {var_name} f"${{{var_name}}}", # ${var_name} ] for placeholder in placeholders: if placeholder in result: result = result.replace(placeholder, str(var_value)) return result # ==================== 多模型对比试运行 ==================== async def run_compare( self, *, agent_config: AgentConfig, models: List[Dict[str, Any]], message: str, workspace_id: uuid.UUID, conversation_id: Optional[str] = None, user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, parallel: bool = True, timeout: int = 60, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, web_search: bool = True, memory: bool = True, files: list[FileInput] | None = None ) -> Dict[str, Any]: """多模型对比试运行 Args: agent_config: Agent 配置 models: 模型配置列表,每项包含 model_config, parameters, label, model_config_id message: 用户消息 workspace_id: 工作空间ID conversation_id: 会话ID user_id: 用户ID variables: 变量参数 
parallel: 是否并行执行 timeout: 超时时间(秒) Returns: Dict: 对比结果 """ logger.info( "多模型对比试运行", extra={ "model_count": len(models), "parallel": parallel } ) # 提前校验文件上传(与 run() 内部保持一致) features_config: dict = agent_config.features or {} if hasattr(features_config, 'model_dump'): features_config = features_config.model_dump() # self._validate_file_upload(features_config, files) async def run_single_model(model_info): """运行单个模型""" try: start_time = time.time() # 临时修改参数(不使用 deepcopy 避免 SQLAlchemy 会话问题) original_params = agent_config.model_parameters agent_config.model_parameters = model_info["parameters"] # 使用模型自己的 conversation_id,如果没有则使用全局的 model_conversation_id = model_info.get("conversation_id") or conversation_id try: result = await asyncio.wait_for( self.run( agent_config=agent_config, model_config=model_info["model_config"], message=message, workspace_id=workspace_id, conversation_id=model_conversation_id, user_id=user_id, variables=variables, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, web_search=web_search, memory=memory, files=files ), timeout=timeout ) finally: # 恢复原始参数 agent_config.model_parameters = original_params elapsed = time.time() - start_time usage = result.get("usage", {}) return { "model_config_id": model_info["model_config_id"], "model_name": model_info["model_config"].name, "label": model_info["label"], "conversation_id": result['conversation_id'], "parameters_used": model_info["parameters"], "message": result.get("message"), "reasoning_content": result.get("reasoning_content"), "usage": usage, "elapsed_time": elapsed, "tokens_per_second": ( usage.get("completion_tokens", 0) / elapsed if elapsed > 0 and usage.get("completion_tokens") else None ), "cost_estimate": self._estimate_cost(usage, model_info["model_config"]), "audio_url": result.get("audio_url"), "audio_status": result.get("audio_status"), "citations": result.get("citations", []), "suggested_questions": result.get("suggested_questions", []), "error": None } except TimeoutError: 
logger.warning( "模型运行超时", extra={ "model_config_id": str(model_info["model_config_id"]), "timeout": timeout } ) return { "model_config_id": model_info["model_config_id"], "model_name": model_info["model_config"].name, "conversation_id": conversation_id, "label": model_info["label"], "parameters_used": model_info["parameters"], "elapsed_time": timeout, "error": f"执行超时({timeout}秒)" } except Exception as e: logger.error( "模型运行失败", extra={ "model_config_id": str(model_info["model_config_id"]), "error": str(e) } ) return { "model_config_id": model_info["model_config_id"], "model_name": model_info["model_config"].name, "label": model_info["label"], "conversation_id": conversation_id, "parameters_used": model_info["parameters"], "elapsed_time": 0, "error": str(e) } # 执行所有模型(并行或串行) if parallel: logger.debug(f"并行执行 {len(models)} 个模型") results = await asyncio.gather( *[run_single_model(m) for m in models], return_exceptions=False ) else: logger.debug(f"串行执行 {len(models)} 个模型") results = [] for model_info in models: result = await run_single_model(model_info) results.append(result) # 统计分析 successful = [r for r in results if not r.get("error")] failed = [r for r in results if r.get("error")] fastest = min(successful, key=lambda x: x["elapsed_time"]) if successful else None cheapest = min( successful, key=lambda x: x.get("cost_estimate") or float("inf") ) if successful else None logger.info( "多模型对比完成", extra={ "successful": len(successful), "failed": len(failed), "total_time": sum(r.get("elapsed_time", 0) for r in results) } ) return { "results": [{ **r, "audio_url": r.get("audio_url"), "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), } for r in results], "total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results), "successful_count": len(successful), "failed_count": len(failed), "fastest_model": fastest["label"] if fastest else None, "cheapest_model": cheapest["label"] if cheapest else 
None } def _estimate_cost(self, usage: Dict[str, Any], model_config) -> Optional[float]: """估算成本 Args: usage: Token 使用情况 model_config: 模型配置 Returns: Optional[float]: 估算成本(美元) """ if not usage: return None prompt_tokens = usage.get("prompt_tokens", 0) completion_tokens = usage.get("completion_tokens", 0) # 简化成本估算:暂时返回 None # TODO: 实现基于模型名称或配置的成本估算 # 需要从 ModelApiKey 获取实际的模型名称,或者在 ModelConfig 中添加 model 字段 return None def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]) -> tuple[AgentConfig, Any]: """创建一个带有覆盖参数的 agent_config(浅拷贝,只修改 model_parameters) Args: agent_config: 原始 Agent 配置 parameters: 要覆盖的参数 Returns: AgentConfig: 修改后的配置(注意:这是同一个对象,只是临时修改了 model_parameters) """ # 保存原始参数 original_params = agent_config.model_parameters # 设置新参数 agent_config.model_parameters = parameters return agent_config, original_params async def run_compare_stream( self, *, agent_config: AgentConfig, models: List[Dict[str, Any]], message: str, workspace_id: uuid.UUID, conversation_id: Optional[str] = None, user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, web_search: bool = True, memory: bool = True, parallel: bool = True, timeout: int = 60, files: list[FileInput] | None = None ) -> AsyncGenerator[str, None]: """多模型对比试运行(流式返回) 参考 run_compare 的实现,支持并行或串行执行 Args: agent_config: Agent 配置 models: 模型配置列表,每项包含 model_config, parameters, label, model_config_id message: 用户消息 workspace_id: 工作空间ID conversation_id: 会话ID user_id: 用户ID variables: 变量参数 storage_type: 存储类型 user_rag_memory_id: RAG 记忆 ID web_search: 是否启用网络搜索 memory: 是否启用记忆 parallel: 是否并行执行 timeout: 超时时间(秒) files: 多模态文件 Yields: str: SSE 格式的事件数据 """ logger.info( "多模型对比流式试运行", extra={"model_count": len(models), "parallel": parallel} ) # 提前校验文件上传 # features_config: dict = agent_config.features or {} # if hasattr(features_config, 'model_dump'): # features_config = features_config.model_dump() # 
self._validate_file_upload(features_config, files) # 发送开始事件 yield self._format_sse_event("compare_start", { "conversation_id": conversation_id, "model_count": len(models), "parallel": parallel, "timestamp": time.time() }) results = [] async def run_single_model_stream(idx: int, model_info: Dict[str, Any], event_queue: asyncio.Queue): """运行单个模型(流式)并将事件放入队列""" model_label = model_info["label"] model_config_id = str(model_info["model_config_id"]) # 使用模型自己的 conversation_id,如果没有则使用全局的 model_conversation_id = model_info.get("conversation_id") or conversation_id try: # 发送模型开始事件 await event_queue.put(self._format_sse_event("model_start", { "model_index": idx, "model_config_id": model_config_id, "model_name": model_info["model_config"].name, "label": model_label, "conversation_id": model_conversation_id, "timestamp": time.time() })) start_time = time.time() full_content = "" full_reasoning = "" returned_conversation_id = model_conversation_id audio_url = None audio_status = None citations = [] suggested_questions = [] # 临时修改参数 original_params = agent_config.model_parameters agent_config.model_parameters = model_info["parameters"] try: # 流式调用单个模型 async for event_str in self.run_stream( agent_config=agent_config, model_config=model_info["model_config"], message=message, workspace_id=workspace_id, conversation_id=model_conversation_id, user_id=user_id, variables=variables, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, web_search=web_search, memory=memory, files=files ): # 解析原始事件 try: lines = event_str.strip().split('\n') event_type = None event_data = None for line in lines: if line.startswith('event: '): event_type = line[7:].strip() elif line.startswith('data: '): event_data = json.loads(line[6:]) # 从 start 事件中获取实际的 conversation_id if event_type == "start" and event_data: conv_id = event_data.get("conversation_id") if conv_id: returned_conversation_id = conv_id # 累积消息内容 if event_type == "message" and event_data: chunk = event_data.get("content", "") 
full_content += chunk # 转发消息块事件(带模型标识) await event_queue.put(self._format_sse_event("model_message", { "model_index": idx, "model_config_id": model_config_id, "label": model_label, "conversation_id": returned_conversation_id, "content": chunk })) # 转发深度思考事件(带模型标识) if event_type == "reasoning" and event_data: reasoning_chunk = event_data.get("content", "") full_reasoning += reasoning_chunk await event_queue.put(self._format_sse_event("model_reasoning", { "model_index": idx, "model_config_id": model_config_id, "label": model_label, "conversation_id": returned_conversation_id, "content": event_data.get("content", "") })) # 从 end 事件中提取 features 输出字段 if event_type == "end" and event_data: audio_url = event_data.get("audio_url") audio_status = event_data.get("audio_status") citations = event_data.get("citations", []) suggested_questions = event_data.get("suggested_questions", []) if event_type == "error" and event_data: await event_queue.put(self._format_sse_event("model_error", { "model_index": idx, "model_config_id": model_config_id, "label": model_label, "conversation_id": returned_conversation_id, "error": event_data.get("error", "未知错误") })) except Exception as e: logger.warning(f"解析流式事件失败: {e}") finally: # 恢复原始参数 agent_config.model_parameters = original_params elapsed = time.time() - start_time # 构建结果(参考 run_compare) result = { "model_config_id": model_info["model_config_id"], "model_name": model_info["model_config"].name, "label": model_label, "conversation_id": returned_conversation_id, "parameters_used": model_info["parameters"], "message": full_content, "reasoning_content": full_reasoning or None, "elapsed_time": elapsed, "audio_url": audio_url, "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "error": None } # 发送模型完成事件 await event_queue.put(self._format_sse_event("model_end", { "model_index": idx, "model_config_id": model_config_id, "label": model_label, "conversation_id": returned_conversation_id, 
"elapsed_time": elapsed, "message_length": len(full_content), "audio_url": audio_url, "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "timestamp": time.time() })) return result except TimeoutError: logger.warning(f"模型运行超时: {model_label}") result = { "model_config_id": model_info["model_config_id"], "model_name": model_info["model_config"].name, "label": model_label, "conversation_id": model_conversation_id, "parameters_used": model_info["parameters"], "elapsed_time": timeout, "error": f"执行超时({timeout}秒)" } await event_queue.put(self._format_sse_event("model_error", { "model_index": idx, "model_config_id": model_config_id, "label": model_label, "conversation_id": model_conversation_id, "error": result["error"], "timestamp": time.time() })) return result except Exception as e: logger.error(f"模型运行失败: {model_label}, error: {e}") result = { "model_config_id": model_info["model_config_id"], "model_name": model_info["model_config"].name, "label": model_label, "conversation_id": model_conversation_id, "parameters_used": model_info["parameters"], "elapsed_time": 0, "error": str(e) } await event_queue.put(self._format_sse_event("model_error", { "model_index": idx, "model_config_id": model_config_id, "label": model_label, "conversation_id": model_conversation_id, "error": str(e), "timestamp": time.time() })) return result if parallel: # 并行执行所有模型(参考 run_compare) logger.debug(f"并行执行 {len(models)} 个模型(流式)") # 创建事件队列 event_queue = asyncio.Queue() # 启动所有模型的并行任务 tasks = [ asyncio.create_task(run_single_model_stream(idx, model_info, event_queue)) for idx, model_info in enumerate(models) ] # 持续从队列中取出事件并转发 completed_tasks = set() while len(completed_tasks) < len(tasks): try: # 尝试从队列获取事件 event = await asyncio.wait_for(event_queue.get(), timeout=0.1) yield event except TimeoutError: # 检查是否有任务完成 for task in tasks: if task.done() and task not in completed_tasks: completed_tasks.add(task) try: result = await task if result: 
results.append(result) except Exception as e: logger.error(f"获取任务结果失败: {e}") continue # 清空队列中剩余的事件 while not event_queue.empty(): try: event = event_queue.get_nowait() yield event except asyncio.QueueEmpty: break else: # 串行执行每个模型(参考 run_compare) logger.debug(f"串行执行 {len(models)} 个模型(流式)") for idx, model_info in enumerate(models): # 创建临时队列用于单个模型 event_queue = asyncio.Queue() # 运行单个模型 result = await run_single_model_stream(idx, model_info, event_queue) if result: results.append(result) # 转发该模型的所有事件 while not event_queue.empty(): try: event = event_queue.get_nowait() yield event except asyncio.QueueEmpty: break # 统计分析(参考 run_compare) successful = [r for r in results if not r.get("error")] failed = [r for r in results if r.get("error")] fastest = min(successful, key=lambda x: x["elapsed_time"]) if successful else None cheapest = min( successful, key=lambda x: x.get("cost_estimate") or float("inf") ) if successful else None # 构建结果摘要(包含完整的 message) results_summary = [] for r in results: results_summary.append({ "model_config_id": str(r["model_config_id"]), "model_name": r["model_name"], "label": r["label"], "conversation_id": r.get("conversation_id"), "message": r.get("message"), "reasoning_content": r.get("reasoning_content"), "elapsed_time": r.get("elapsed_time", 0), "audio_url": r.get("audio_url"), "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), "error": r.get("error") }) # 发送对比完成事件(参考 run_compare 的返回格式) yield self._format_sse_event("compare_end", { "conversation_id": conversation_id, "results": results_summary, # 包含完整结果 "total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results), "successful_count": len(successful), "failed_count": len(failed), "fastest_model": fastest["label"] if fastest else None, "cheapest_model": cheapest["label"] if cheapest else None, "timestamp": time.time() }) logger.info( "多模型对比流式完成", extra={ "successful": len(successful), "failed": len(failed), 
"total_time": sum(r.get("elapsed_time", 0) for r in results) } )