Files
MemoryBear/api/app/services/draft_run_service.py
Timebomb2018 faf8d1a51a fix(workflow): add reasoning content, suggested questions, citations and audio status support
- Introduce `reasoning_content`, `suggested_questions`, `citations`, and `audio_status` fields in conversation and app response schemas
- Conditionally set `audio_status` to `"pending"` only when `audio_url` is present
- Replace `model_dump` override with `@model_serializer(mode="wrap")` for cleaner serialization logic
- Change knowledge base validation failure from `RuntimeError` to warning + `continue` to avoid halting retrieval on invalid KB
2026-04-27 15:35:26 +08:00

2493 lines
101 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
试运行服务
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
"""
import asyncio
import datetime
import json
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.agents import create_agent
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.memory.enums import SearchStrategy
from app.core.memory.memory_service import MemoryService
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput, Citation, FileType
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services.conversation_service import ConversationService
from app.services.langchain_tool_server import Search
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.model_service import ModelApiKeyService
from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
logger = get_business_logger()
# Argument schema for the knowledge-base retrieval tool (used as ``args_schema`` of ``@tool``).
# NOTE(review): the class docstring and the Field description are intentionally left
# untranslated - pydantic can expose both in the generated schema, so they are
# presumably part of what the model sees; confirm before changing them.
class KnowledgeRetrievalInput(BaseModel):
    """知识库检索工具输入参数"""
    # Question or keywords to look up in the knowledge base.
    query: str = Field(description="需要检索的问题或关键词")
# Argument schema for the web search tool (used as ``args_schema`` of ``@tool``).
# NOTE(review): the class docstring and the Field description are intentionally left
# untranslated - pydantic can expose both in the generated schema, so they are
# presumably part of what the model sees; confirm before changing them.
class WebSearchInput(BaseModel):
    """网络搜索工具输入参数"""
    # Question or keywords to search the web for.
    query: str = Field(description="需要搜索的问题或关键词")
# Argument schema for the long-term memory tool (used as ``args_schema`` of ``@tool``).
# NOTE(review): the class docstring and the Field description are intentionally left
# untranslated - pydantic can expose both in the generated schema, so they are
# presumably part of what the model sees; confirm before changing them.
class LongTermMemoryInput(BaseModel):
    """长期记忆工具输入参数"""
    # Rewritten query; the description asks the caller to rephrase the user's
    # original question into a retrieval-friendly form.
    question: str = Field(
        description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
def create_long_term_memory_tool(
        memory_config: Dict[str, Any],
        end_user_id: str,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None
):
    """Build the LangChain tool that reads a user's long-term memory.

    Args:
        memory_config: Memory settings. ``memory_config_id`` is read first and
            the legacy ``memory_content`` key is the fallback.
        end_user_id: ID of the end user whose memory is searched.
        storage_type: Storage type (optional). Only written to the log here.
        user_rag_memory_id: User RAG memory ID (optional). Not used by the
            current implementation.

    Returns:
        The ``long_term_memory`` tool.
    """
    # Prefer the new field name, fall back to the legacy one.
    config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None)
    logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")

    # The docstring below is kept verbatim: ``@tool`` turns a function's
    # docstring into the tool description, so it is runtime text.
    @tool(args_schema=LongTermMemoryInput)
    def long_term_memory(question: str) -> str:
        """
        从用户的历史记忆中检索相关信息。用于了解用户的背景、偏好和历史对话内容。
        **何时使用此工具:**
        - 用户明确询问历史信息(如"我之前说过什么""上次我们聊了什么"
        - 用户询问个人信息或偏好(如"我喜欢什么""我的习惯是什么"
        - 需要基于历史上下文提供个性化建议
        **何时不使用此工具:**
        - 简单问候(如"你好""谢谢""再见"
        - 纯任务性请求(如"写代码""翻译文字""分析图片"
        - 用户已提供完整信息(如提供了文本、图片、文档等内容)
        - 创作性任务(如"写诗""编故事""创作谜语"
        **重要:如果用户的问题可以直接回答,不要调用此工具。只在确实需要历史信息时才使用。**
        Args:
            question: 需要检索的问题(保持原问题的核心语义,使用清晰的关键词,第三人称描述的偏好、行为通常指用户本人,比如(我,本人,在下,自己,咱,鄙人,吴,余)通指用户)
        Returns:
            检索到的历史记忆内容
        """
        logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
        try:
            with get_db_context() as session:
                reader = MemoryService(session, config_id, end_user_id)
                # The tool is synchronous, so the async read is driven to
                # completion with its own event loop.
                search_result = asyncio.run(reader.read(question, SearchStrategy.QUICK))
                return f"检索到以下历史记忆:\n\n{search_result.content}"
        except Exception as exc:
            # Failures are reported back to the agent as text, never raised.
            failure = {"error": str(exc), "error_type": type(exc).__name__}
            logger.error("长期记忆检索失败", extra=failure)
            return f"记忆检索失败: {str(exc)}"

    return long_term_memory
def create_web_search_tool(web_search_config: Dict[str, Any]):
    """Build the LangChain tool that performs a web search.

    Args:
        web_search_config: Web search settings. Not read by the current
            implementation.

    Returns:
        The ``web_search_tool`` tool.
    """
    logger.info("创建网络搜索工具")

    # The docstring below is kept verbatim: ``@tool`` turns a function's
    # docstring into the tool description, so it is runtime text.
    @tool(args_schema=WebSearchInput)
    def web_search_tool(query: str) -> str:
        """从互联网搜索最新信息。当用户的问题需要实时信息、最新新闻或网络资料时,使用此工具进行搜索。
        Args:
            query: 需要搜索的问题或关键词
        Returns:
            搜索到的相关网络信息
        """
        try:
            logger.info(f"执行网络搜索: {query}")
            # Delegate to the search service.
            hits = Search(query)
            success_fields = {
                "query": query,
                "result_length": len(hits)
            }
            logger.info("网络搜索成功", extra=success_fields)
            return f"搜索到以下网络信息:\n\n{hits}"
        except Exception as exc:
            # Failures are reported back to the agent as text, never raised.
            failure = {"error": str(exc), "error_type": type(exc).__name__}
            logger.error("网络搜索失败", extra=failure)
            return f"搜索失败: {str(exc)}"

    return web_search_tool
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None):
    """Build the LangChain tool that retrieves content from the knowledge bases.

    Args:
        kb_config: Knowledge retrieval configuration, passed to
            ``knowledge_retrieval`` unchanged.
        kb_ids: Knowledge base IDs. Only used for logging here.
        user_id: User ID. Only used for logging here.
        citations_collector: Optional list owned by the caller. Each tool call
            appends one ``Citation`` (``document_id``, ``file_name``,
            ``knowledge_id``, ``score``) per document not yet present in it.

    Returns:
        The ``knowledge_retrieval_tool`` tool.
    """
    logger.info(f"创建知识库检索工具,用户:{user_id}")

    def _collected_document_id(entry):
        # The collector holds ``Citation`` objects appended below, which are
        # read by attribute; plain dicts are accepted too in case a caller
        # pre-fills the list with them.
        if isinstance(entry, dict):
            return entry.get("document_id")
        return getattr(entry, "document_id", None)

    # The docstring below is kept verbatim: ``@tool`` turns a function's
    # docstring into the tool description, so it is runtime text.
    @tool(args_schema=KnowledgeRetrievalInput)
    def knowledge_retrieval_tool(query: str) -> str:
        """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
        Args:
            query: 需要检索的问题或关键词
        Returns:
            检索到的相关知识内容
        """
        try:
            retrieve_chunks_result = knowledge_retrieval(query, kb_config)
            if not retrieve_chunks_result:
                logger.warning("知识库检索未找到结果")
                return "未找到相关信息"

            retrieval_knowledge = [chunk.page_content for chunk in retrieve_chunks_result]
            context = '\n\n'.join(retrieval_knowledge)
            logger.info(
                "知识库检索成功",
                extra={
                    "kb_ids": kb_ids,
                    "result_count": len(retrieval_knowledge),
                    "total_length": len(context)
                }
            )

            # Record one citation per document, de-duplicated across calls.
            if citations_collector is not None:
                # Calling ``.get`` on the stored ``Citation`` objects would
                # make every call after the first one fail, hence the helper.
                seen_doc_ids = {_collected_document_id(c) for c in citations_collector}
                for chunk in retrieve_chunks_result:
                    meta = chunk.metadata or {}
                    doc_id = meta.get("document_id") or meta.get("doc_id")
                    if not doc_id or doc_id in seen_doc_ids:
                        continue
                    seen_doc_ids.add(doc_id)
                    citations_collector.append(Citation(
                        document_id=doc_id,
                        file_name=meta.get("file_name", ""),
                        knowledge_id=str(meta.get("knowledge_id", "")),
                        score=meta.get("score", 0)
                    ))
            return f"检索到以下相关信息:\n\n{context}"
        except Exception as e:
            # Failures are reported back to the agent as text, never raised.
            logger.error("知识库检索失败", extra={"error": str(e), "error_type": type(e).__name__})
            return f"检索失败: {str(e)}"

    return knowledge_retrieval_tool
class AgentRunService:
    """Service that runs an agent (draft / trial runs)."""

    def __init__(self, db: Session):
        """Create the service.

        Args:
            db: Database session, stored on the instance as ``self.db``.
        """
        self.db = db
@staticmethod
def prepare_variables(
input_vars: dict | None,
variables_config: dict
) -> dict:
input_vars = input_vars or {}
for variable in variables_config:
if variable.get("required") and variable.get("name") not in input_vars:
raise ValueError(f"The required parameter '{variable.get('name')}' was not provided")
return input_vars
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
    """Collect the web search tool and the enabled configured tools.

    Args:
        tools_config: Tool configuration; entries are only read when it is a
            list of dicts.
        web_search: When truthy, a web search tool is added first.
        tenant_id: Tenant ID used to look up tool instances.

    Returns:
        List of LangChain tools.
    """
    loaded = [create_web_search_tool({})] if web_search else []
    if not tools_config:
        return loaded

    tool_service = ToolService(self.db)
    if not isinstance(tools_config, list):
        return loaded

    for entry in tools_config:
        if not entry.get("enabled", False):
            continue
        # Look up the tool instance by its ID for this tenant.
        instance = tool_service.get_tool_instance(entry.get("tool_id", ""), tenant_id)
        if not instance:
            continue
        # Convert to a LangChain tool.
        loaded.append(instance.to_langchain_tool(entry.get("operation", None)))
        # NOTE(review): the message text says "web search tool" although it is
        # logged for configured tools; left unchanged here.
        logger.debug("已添加网络搜索工具", extra={"tool_count": len(loaded)})
    return loaded
def load_skill_config(
self,
skills_config: dict | None,
message: str, tenant_id
) -> tuple[list, str]:
if not skills_config:
return [], ""
tools = []
skill_prompts = ""
skill_enable = skills_config.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills_config)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
skill_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
return tools, skill_prompts
def load_knowledge_retrieval_config(
self,
knowledge_retrieval_config: dict | None,
user_id
) -> tuple[list, list]:
"""返回 (tools, citations_collector)"""
if not knowledge_retrieval_config:
return [], []
citations_collector = []
tools = []
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(
knowledge_retrieval_config, kb_ids, user_id,
citations_collector=citations_collector
)
tools.append(kb_tool)
logger.debug(
"已添加知识库检索工具",
extra={"kb_ids": kb_ids, "tool_count": len(tools)}
)
return tools, citations_collector
def load_memory_config(
self,
memory_config: dict | None,
user_id,
storage_type,
user_rag_memory_id
) -> tuple[list, bool]:
"""加载长期记忆配置"""
if not memory_config:
return [], False
tools = []
if memory_config.get("enabled"):
if user_id:
# 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)
logger.debug(
"已添加长期记忆工具",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
return tools, bool(memory_config.get("enabled"))
@staticmethod
def _validate_file_upload(
        features_config: Dict[str, Any],
        files: Optional[List[FileInput]]
) -> None:
    """Validate uploaded files against the ``file_upload`` feature settings.

    Nothing is checked when ``files`` or ``features_config`` is empty.

    Args:
        features_config: Feature configuration; its ``file_upload`` entry is read.
        files: Uploaded files, or ``None``.

    Raises:
        BusinessException: With ``BizCode.BAD_REQUEST`` when file upload is not
            enabled, there are too many files, a transfer method is not
            allowed, a file type is switched off, or loaded content exceeds the
            size limit for its type.
    """
    if not files or not features_config:
        return

    fu = features_config.get("file_upload", {})
    if not (isinstance(fu, dict) and fu.get("enabled")):
        raise BusinessException("该应用未开启文件上传功能", BizCode.BAD_REQUEST)

    max_count = fu.get("max_file_count", 5)
    if len(files) > max_count:
        raise BusinessException(f"文件数量超过限制(最多 {max_count} 个)", BizCode.BAD_REQUEST)

    # Check the transfer method of every file.
    allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"])
    for f in files:
        if f.transfer_method.value not in allowed_methods:
            raise BusinessException(
                f"不支持的文件传输方式:{f.transfer_method.value},允许的方式:{', '.join(allowed_methods)}",
                BizCode.BAD_REQUEST
            )

    # Per type: (switch key, size-limit key, default limit in MB, label).
    type_cfg = {
        "image": ("image_enabled", "image_max_size_mb", 20, "图片"),
        "audio": ("audio_enabled", "audio_max_size_mb", 50, "音频"),
        "document": ("document_enabled", "document_max_size_mb", 100, "文档"),
        "video": ("video_enabled", "video_max_size_mb", 500, "视频"),
    }
    for f in files:
        # Use the enum's value when ``type`` is an enum member: ``str()`` of a
        # ``(str, Enum)`` member is "ClassName.MEMBER", which would miss every
        # key above and silently skip the checks. Plain strings pass through.
        ftype = str(getattr(f.type, "value", f.type))  # e.g. "image", "audio", "document", "video"
        cfg = type_cfg.get(ftype)
        if cfg is None:
            continue
        enabled_key, size_key, default_max_mb, label = cfg

        # Check the per-type switch.
        if not fu.get(enabled_key):
            raise BusinessException(f"该应用未开启{label}文件上传", BizCode.BAD_REQUEST)

        # Check the size, only when the content has already been loaded.
        content = f.get_content()
        if content is not None:
            max_mb = fu.get(size_key, default_max_mb)
            size_mb = len(content) / (1024 * 1024)
            if size_mb > max_mb:
                raise BusinessException(
                    f"{label}文件大小超过限制(最大 {max_mb}MB当前 {size_mb:.1f}MB",
                    BizCode.BAD_REQUEST
                )
@staticmethod
def _get_opening_statement(
features_config: Dict[str, Any],
is_new_conversation: bool,
variables: Optional[Dict[str, Any]] = None
) -> tuple[Any, Any]:
"""首轮对话时返回开场白文本(支持变量替换),否则返回 None"""
if not is_new_conversation:
return None, None
opening = features_config.get("opening_statement", {})
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
return None, None
statement = opening["statement"]
suggested_questions = opening["suggested_questions"]
# 如果有变量,进行替换(仅支持 {{var_name}} 格式)
if variables:
for var_name, var_value in variables.items():
placeholder = f"{{{{{var_name}}}}}"
statement = statement.replace(placeholder, str(var_value))
return statement, suggested_questions
@staticmethod
def _filter_citations(
features_config: Dict[str, Any],
citations: List[Citation]
) -> List[Any]:
"""根据 citation 开关决定是否返回引用来源,并根据 allow_download 附加下载链接"""
citation_cfg = features_config.get("citation", {})
if not (isinstance(citation_cfg, dict) and citation_cfg.get("enabled")):
return []
allow_download = citation_cfg.get("allow_download", False)
result = []
for cit in citations:
item = cit.model_dump() if hasattr(cit, "model_dump") else dict(cit)
if allow_download and item.get("document_id"):
from app.core.config import settings
item["download_url"] = f"{settings.FILE_LOCAL_SERVER_URL}/apps/citations/{item['document_id']}/download"
result.append(item)
return result
async def run(
        self,
        *,
        agent_config: AgentConfig,
        model_config: ModelConfig,
        message: str,
        workspace_id: uuid.UUID,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
        web_search: bool = True,
        memory: bool = True,
        sub_agent: bool = False,
        files: Optional[List[FileInput]] = None  # multimodal files
) -> Dict[str, Any]:
    """Execute a draft run with a LangChain agent and return the complete reply.

    Args:
        agent_config: Agent configuration.
        model_config: Model configuration.
        message: User message.
        workspace_id: Workspace ID, used for conversation isolation.
        conversation_id: Conversation ID for multi-turn chats; a new
            conversation is created when it is empty.
        user_id: User ID.
        variables: Values for the agent's custom variables.
        storage_type: Storage type (optional).
        user_rag_memory_id: User RAG memory ID (optional).
        web_search: Whether web search may be used; forced off unless the
            ``web_search`` feature is enabled.
        memory: Whether long-term memory may be used.
        sub_agent: Whether this is a sub-agent call. Sub-agent calls skip
            variable validation, the opening statement, TTS, message
            persistence and suggested questions.
        files: Multimodal files (optional).

    Returns:
        Dict with ``message``, ``reasoning_content``, ``conversation_id``,
        ``usage``, ``elapsed_time``, ``suggested_questions``, ``citations``,
        ``audio_url`` and ``audio_status``.

    Raises:
        BusinessException: From file-upload validation, or wrapping any
            failure inside the run with ``BizCode.INTERNAL_ERROR``.
    """
    start_time = time.time()
    tools_config: dict | list | None = agent_config.tools
    skills_config: dict | None = agent_config.skills
    knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
    memory_config: dict | None = agent_config.memory
    features_config: dict = agent_config.features or {}

    # The feature switch in ``features`` overrides the ``web_search`` argument.
    web_search_feature = features_config.get("web_search", {})
    if not isinstance(web_search_feature, dict) or not web_search_feature.get("enabled"):
        web_search = False

    # Validate uploaded files against the file_upload feature.
    self._validate_file_upload(features_config, files)

    try:
        # 1. Resolve the API key configuration.
        api_key_config = await self._get_api_key(model_config.id)
        logger.debug(
            "API Key 配置获取成功",
            extra={
                "model_name": api_key_config["model_name"],
                "has_api_key": bool(api_key_config["api_key"]),
                "has_api_base": bool(api_key_config.get("api_base"))
            }
        )

        # 2. Merge model parameters.
        effective_params = ModelParameterMerger.get_effective_parameters(
            model_config=model_config,
            agent_config=agent_config
        )

        # Required variables are validated for top-level runs only, matching
        # run_stream; the check used to be inverted here.
        if not sub_agent:
            variables = self.prepare_variables(variables, agent_config.variables or [])
        else:
            # FIXME: subagent input valid
            variables = variables or {}

        # 3. Render the system prompt (with variable substitution).
        system_prompt = render_prompt_message(
            agent_config.system_prompt,
            PromptMessageRole.USER,
            variables
        )
        system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"

        # 4. Assemble the tool list.
        tools = []
        tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
        tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))

        skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
        tools.extend(skill_tools)
        if skill_prompts:
            system_prompt = f"{system_prompt}\n\n{skill_prompts}"

        kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
        tools.extend(kb_tools)

        # Long-term memory tool.
        if memory:
            memory_tools, _memory_enabled = self.load_memory_config(
                memory_config, user_id, storage_type, user_rag_memory_id
            )
            tools.extend(memory_tools)

        # 5. Create or verify the conversation; a new one gets the opening statement.
        is_new_conversation = not conversation_id
        opening, suggested_questions = None, None
        if not sub_agent:
            opening, suggested_questions = self._get_opening_statement(
                features_config, is_new_conversation, variables
            )

        # NOTE(review): run_stream also passes sub_agent=sub_agent to
        # _ensure_conversation; confirm whether this call should as well.
        conversation_id = await self._ensure_conversation(
            conversation_id=conversation_id,
            app_id=agent_config.app_id,
            workspace_id=workspace_id,
            user_id=user_id,
            opening_statement=opening,
            suggested_questions=suggested_questions
        )

        model_info = ModelInfo(
            model_name=api_key_config["model_name"],
            provider=api_key_config["provider"],
            api_key=api_key_config["api_key"],
            api_base=api_key_config["api_base"],
            capability=api_key_config["capability"],
            is_omni=api_key_config["is_omni"],
            model_type=model_config.type
        )

        # 6. Load the conversation history (includes the opening statement).
        # NOTE(review): the history size is fixed at 10 here while run_stream
        # reads ``max_history`` from the memory configuration.
        history = await self._load_conversation_history(
            conversation_id=conversation_id,
            max_history=10,
            current_provider=api_key_config.get("provider"),
            current_is_omni=api_key_config.get("is_omni", False)
        )

        # 7. Process multimodal files.
        processed_files = None
        has_doc_with_images = False
        if files:
            provider = api_key_config.get("provider", "openai")
            multimodal_service = MultimodalService(self.db, model_info)
            fu_config = features_config.get("file_upload", {})
            if hasattr(fu_config, "model_dump"):
                fu_config = fu_config.model_dump()
            doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False)

            processed_files = await multimodal_service.process_files(
                files, document_image_recognition=doc_img_recognition,
                workspace_id=workspace_id
            )
            logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")

            capability = api_key_config.get("capability", [])
            has_doc_with_images = (
                doc_img_recognition
                and "vision" in capability
                and any(f.type == FileType.DOCUMENT for f in files)
            )

        if has_doc_with_images:
            system_prompt += (
                "\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
            )

        agent = LangChainAgent(
            model_name=api_key_config["model_name"],
            api_key=api_key_config["api_key"],
            provider=api_key_config.get("provider", "openai"),
            api_base=api_key_config.get("api_base"),
            is_omni=api_key_config.get("is_omni", False),
            temperature=effective_params.get("temperature", 0.7),
            max_tokens=effective_params.get("max_tokens", 2000),
            system_prompt=system_prompt,
            tools=tools,
            deep_thinking=effective_params.get("deep_thinking", False),
            thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
            json_output=effective_params.get("json_output", False),
            capability=api_key_config.get("capability", []),
        )

        # Inject runtime context into the tools that accept it.
        for t in tools:
            if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
                t.tool_instance.set_runtime_context(
                    user_id=user_id or "anonymous",
                    conversation_id=str(conversation_id) if conversation_id else None,
                    uploaded_files=processed_files or []
                )

        # No pre-retrieved context is passed; retrieval happens through tools.
        context = None

        logger.debug(
            "准备调用 LangChain Agent",
            extra={
                "model": api_key_config["model_name"],
                "has_history": bool(history),
                "has_context": bool(context),
                "has_files": bool(processed_files)
            }
        )

        # 8. Call the agent (multimodal aware).
        result = await agent.chat(
            message=message,
            history=history,
            context=context,
            files=processed_files
        )

        elapsed_time = time.time() - start_time
        ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))

        # 9. Generate the TTS audio URL before saving so it is stored with the message.
        audio_url = await self._generate_tts(
            features_config, result["content"], api_key_config,
            tenant_id=tenant_id, workspace_id=workspace_id
        ) if not sub_agent else None

        # Filter citations once and reuse the result below.
        filtered_citations = self._filter_citations(features_config, citations_collector)

        # 10. Persist the exchange.
        if not sub_agent:
            await self._save_conversation_message(
                conversation_id=conversation_id,
                user_message=message,
                assistant_message=result["content"],
                app_id=agent_config.app_id,
                user_id=user_id,
                meta_data={
                    "usage": result.get("usage", {
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                        "total_tokens": 0
                    }),
                    "reasoning_content": result.get("reasoning_content")
                },
                files=files,
                processed_files=processed_files,
                audio_url=audio_url,
                citations=filtered_citations,
                provider=api_key_config.get("provider"),
                is_omni=api_key_config.get("is_omni", False)
            )

        response = {
            "message": result["content"],
            "reasoning_content": result.get("reasoning_content"),
            "conversation_id": conversation_id,
            "usage": result.get("usage", {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }),
            "elapsed_time": elapsed_time,
            "suggested_questions": await self._generate_suggested_questions(
                features_config, result["content"], api_key_config, effective_params
            ) if not sub_agent else [],
            "citations": filtered_citations,
            "audio_url": audio_url,
            # Audio is reported as pending only when a URL was produced.
            "audio_status": "pending" if audio_url else None
        }

        logger.info(
            "试运行完成",
            extra={
                "model": model_config.name,
                "elapsed_time": elapsed_time,
                "message_length": len(result["content"]),
                "total_tokens": result.get("usage", {}).get("total_tokens", 0)
            }
        )
        return response
    except Exception as e:
        logger.error("LangChain Agent 调用失败", extra={"error": str(e), "error_type": type(e).__name__})
        raise BusinessException(f"Agent 调用失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
async def run_stream(
        self,
        *,
        agent_config: AgentConfig,
        model_config: ModelConfig,
        message: str,
        workspace_id: uuid.UUID,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
        web_search: bool = True,
        memory: bool = True,
        sub_agent: bool = False,  # whether this runs as a sub-agent
        files: Optional[List[FileInput]] = None  # multimodal files
) -> AsyncGenerator[str, None]:
    """Execute a draft run with a LangChain agent and stream it as SSE events.

    Events are emitted in this order: ``start``; ``reasoning`` / ``message``
    chunks; ``sub_usage`` (sub-agent runs only); ``end``. A failure inside the
    run is reported as a single ``error`` event instead of being raised.

    Args:
        agent_config: Agent configuration.
        model_config: Model configuration.
        message: User message.
        workspace_id: Workspace ID, used for conversation isolation.
        conversation_id: Conversation ID for multi-turn chats; a new
            conversation is created when it is empty.
        user_id: User ID.
        variables: Values for the agent's custom variables.
        storage_type: Storage type (optional).
        user_rag_memory_id: User RAG memory ID (optional).
        web_search: Whether web search may be used; forced off unless the
            ``web_search`` feature is enabled.
        memory: Whether long-term memory may be used.
        sub_agent: Whether this is a sub-agent call. Sub-agent calls skip
            variable validation, the opening statement, TTS, message
            persistence and suggested questions.
        files: Multimodal files (optional).

    Yields:
        str: SSE-formatted event data.

    Raises:
        BusinessException: From file-upload validation, which runs before
            streaming starts.
    """
    tools_config: dict | list | None = agent_config.tools
    skills_config: dict | None = agent_config.skills
    knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
    memory_config: dict | None = agent_config.memory
    features_config: dict = agent_config.features or {}

    # The feature switch in ``features`` overrides the ``web_search`` argument.
    web_search_feature = features_config.get("web_search", {})
    if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
        web_search = False

    # Validate uploaded files against the file_upload feature.
    self._validate_file_upload(features_config, files)

    start_time = time.time()

    # Declared before the try block so the cleanup in ``finally`` can tell
    # whether a TTS consumer is still waiting for input.
    text_queue: Optional[asyncio.Queue] = None
    tts_task = None
    tts_input_closed = False

    try:
        # 1. Resolve the API key configuration.
        api_key_config = await self._get_api_key(model_config.id)

        if not sub_agent:
            variables = self.prepare_variables(variables, agent_config.variables or [])
        else:
            # FIXME: subagent input valid
            variables = variables or {}

        # 2. Merge model parameters.
        effective_params = ModelParameterMerger.get_effective_parameters(
            model_config=model_config,
            agent_config=agent_config
        )

        # 3. Render the system prompt (with variable substitution).
        system_prompt = render_prompt_message(
            agent_config.system_prompt,
            PromptMessageRole.USER,
            variables
        )
        system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"

        # 4. Assemble the tool list.
        tools = []
        tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
        tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))

        skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
        tools.extend(skill_tools)
        if skill_prompts:
            system_prompt = f"{system_prompt}\n\n{skill_prompts}"

        kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
        tools.extend(kb_tools)

        # Long-term memory tool.
        if memory:
            memory_tools, _memory_enabled = self.load_memory_config(
                memory_config, user_id, storage_type, user_rag_memory_id
            )
            tools.extend(memory_tools)

        # 5. Create or verify the conversation; a new one gets the opening statement.
        is_new_conversation = not conversation_id
        opening, suggested_questions = None, None
        if not sub_agent:
            opening, suggested_questions = self._get_opening_statement(
                features_config, is_new_conversation, variables
            )

        conversation_id = await self._ensure_conversation(
            conversation_id=conversation_id,
            app_id=agent_config.app_id,
            workspace_id=workspace_id,
            user_id=user_id,
            sub_agent=sub_agent,
            opening_statement=opening,
            suggested_questions=suggested_questions
        )

        model_info = ModelInfo(
            model_name=api_key_config["model_name"],
            provider=api_key_config["provider"],
            api_key=api_key_config["api_key"],
            api_base=api_key_config["api_base"],
            capability=api_key_config["capability"],
            is_omni=api_key_config["is_omni"],
            model_type=model_config.type
        )

        # 6. Load the conversation history. The memory configuration may be
        # None, in which case the default history size is used.
        history = await self._load_conversation_history(
            conversation_id=conversation_id,
            max_history=(memory_config or {}).get("max_history", 10),
            current_provider=api_key_config.get("provider"),
            current_is_omni=api_key_config.get("is_omni", False)
        )

        # 7. Process multimodal files.
        processed_files = None
        has_doc_with_images = False
        if files:
            provider = api_key_config.get("provider", "openai")
            multimodal_service = MultimodalService(self.db, model_info)
            fu_config = features_config.get("file_upload", {})
            if hasattr(fu_config, "model_dump"):
                fu_config = fu_config.model_dump()
            doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False)

            processed_files = await multimodal_service.process_files(
                files, document_image_recognition=doc_img_recognition,
                workspace_id=workspace_id
            )
            logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")

            capability = api_key_config.get("capability", [])
            has_doc_with_images = (
                doc_img_recognition
                and "vision" in capability
                and any(f.type == FileType.DOCUMENT for f in files)
            )

        if has_doc_with_images:
            system_prompt += (
                "\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
            )

        # Create the LangChain agent in streaming mode.
        agent = LangChainAgent(
            model_name=api_key_config["model_name"],
            api_key=api_key_config["api_key"],
            provider=api_key_config.get("provider", "openai"),
            api_base=api_key_config.get("api_base"),
            is_omni=api_key_config.get("is_omni", False),
            temperature=effective_params.get("temperature", 0.7),
            max_tokens=effective_params.get("max_tokens", 2000),
            system_prompt=system_prompt,
            tools=tools,
            streaming=True,
            deep_thinking=effective_params.get("deep_thinking", False),
            thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
            json_output=effective_params.get("json_output", False),
            capability=api_key_config.get("capability", []),
        )

        # Inject runtime context into the tools that accept it.
        for t in tools:
            if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
                t.tool_instance.set_runtime_context(
                    user_id=user_id or "anonymous",
                    conversation_id=str(conversation_id) if conversation_id else None,
                    uploaded_files=processed_files or []
                )

        # No pre-retrieved context is passed; retrieval happens through tools.
        context = None

        # 8. Emit the start event.
        yield self._format_sse_event("start", {
            "conversation_id": conversation_id,
            "timestamp": time.time()
        })

        # 9. Stream the agent output (multimodal aware) while TTS runs in parallel.
        full_content = ""
        full_reasoning = ""
        total_tokens = 0

        # Start streaming TTS: text is synthesized while it is produced.
        text_queue = asyncio.Queue()
        stream_audio_url, tts_task = await self._generate_tts_streaming(
            features_config, api_key_config,
            text_queue=text_queue,
            tenant_id=tenant_id, workspace_id=workspace_id
        ) if not sub_agent else (None, None)

        async for chunk in agent.chat_stream(
                message=message,
                history=history,
                context=context,
                files=processed_files
        ):
            if isinstance(chunk, int):
                # An int chunk carries the token count.
                total_tokens = chunk
            elif isinstance(chunk, dict) and chunk.get("type") == "reasoning":
                full_reasoning += chunk["content"]
                yield self._format_sse_event("reasoning", {"content": chunk["content"]})
            else:
                full_content += chunk
                yield self._format_sse_event("message", {"content": chunk})
                if tts_task is not None:
                    await text_queue.put(chunk)

        # End of text: ``None`` on the queue notifies the TTS side.
        if tts_task is not None:
            await text_queue.put(None)
            tts_input_closed = True

        elapsed_time = time.time() - start_time
        ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))

        if sub_agent:
            yield self._format_sse_event("sub_usage", {"total_tokens": total_tokens})

        # Filter citations once and reuse the result below.
        filtered_citations = self._filter_citations(features_config, citations_collector)

        # 10. Persist the exchange.
        if not sub_agent:
            await self._save_conversation_message(
                conversation_id=conversation_id,
                user_message=message,
                assistant_message=full_content,
                app_id=agent_config.app_id,
                user_id=user_id,
                meta_data={
                    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens},
                    "reasoning_content": full_reasoning or None
                },
                files=files,
                processed_files=processed_files,
                audio_url=stream_audio_url,
                citations=filtered_citations,
                provider=api_key_config.get("provider"),
                is_omni=api_key_config.get("is_omni", False)
            )

        # 11. Emit the end event (suggested questions, audio URL/status, citations).
        end_data: Dict[str, Any] = {
            "conversation_id": conversation_id,
            "elapsed_time": elapsed_time,
            "message_length": len(full_content)
        }
        if not sub_agent:
            end_data["suggested_questions"] = await self._generate_suggested_questions(
                features_config, full_content, api_key_config, effective_params
            )
            end_data["audio_url"] = stream_audio_url

            # Check the TTS task without blocking on it and without cancelling it.
            audio_status = "pending"
            if tts_task is not None and tts_task.done():
                try:
                    tts_task.result()
                    audio_status = "completed"
                except Exception as e:
                    logger.warning(f"TTS任务异常: {e}")
                    audio_status = "failed"
            end_data["audio_status"] = audio_status if stream_audio_url else None
            end_data["citations"] = filtered_citations

        yield self._format_sse_event("end", end_data)

        logger.info(
            "流式试运行完成",
            extra={
                "model": model_config.name,
                "elapsed_time": elapsed_time,
                "message_length": len(full_content)
            }
        )
    except Exception as e:
        logger.error("流式 Agent 调用失败", extra={"error": str(e)}, exc_info=True)
        # Report the failure to the client as an error event.
        yield self._format_sse_event("error", {
            "error": str(e),
            "timestamp": time.time()
        })
    finally:
        # If streaming stopped early (error, or the consumer closed the
        # generator), still send the end-of-text marker so the TTS side is not
        # left waiting on the queue. put_nowait avoids awaiting during cleanup.
        if tts_task is not None and text_queue is not None and not tts_input_closed:
            text_queue.put_nowait(None)
def _format_sse_event(self, event: str, data: Dict[str, Any]) -> str:
"""格式化 SSE 事件
Args:
event: 事件类型
data: 事件数据
Returns:
str: SSE 格式的字符串
"""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict:
    """Fetch an available API key for a model configuration.

    Args:
        model_config_id: Model configuration ID.

    Returns:
        Dict: ``model_name``, ``provider``, ``api_key``, ``api_base``,
        ``api_key_id``, ``is_omni`` and ``capability`` of the selected key.

    Raises:
        BusinessException: With ``BizCode.AGENT_CONFIG_MISSING`` when no API
            key is available.
    """
    key_record = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
    if key_record:
        # Fields copied under their own names; the record's id is exposed as
        # ``api_key_id``.
        details = {
            field: getattr(key_record, field)
            for field in ("model_name", "provider", "api_key", "api_base")
        }
        details["api_key_id"] = key_record.id
        details["is_omni"] = key_record.is_omni
        details["capability"] = key_record.capability
        return details
    raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
async def _ensure_conversation(
    self,
    conversation_id: Optional[str],
    app_id: uuid.UUID,
    workspace_id: uuid.UUID,
    user_id: Optional[str],
    sub_agent: bool = False,
    opening_statement: Optional[str] = None,
    suggested_questions: Optional[List[str]] = None
) -> str:
    """Return a usable conversation id, creating a draft conversation if needed.

    Args:
        conversation_id: Existing conversation id; a falsy value creates a
            new draft conversation.
        app_id: Application the conversation belongs to.
        workspace_id: Workspace issuing the request.
        user_id: User id stored on a newly created conversation.
        sub_agent: When True the workspace ownership check is skipped.
        opening_statement: Stored as the first assistant message of a newly
            created conversation.
        suggested_questions: Stored in that opening message's metadata.

    Returns:
        str: The given ``conversation_id`` or the id of the new conversation.

    Raises:
        BusinessException: ``PERMISSION_DENIED`` when the conversation is in
            another workspace, no matching ``AppShare`` row exists and it
            belongs to a different app; ``NOT_FOUND`` when looking the
            conversation up fails for any other reason.
    """
    from app.models import Conversation as ConversationModel
    from app.services.conversation_service import ConversationService

    service = ConversationService(self.db)

    if conversation_id:
        # An id was supplied: verify it exists and may be used from here.
        try:
            existing = service.get_conversation(uuid.UUID(conversation_id))

            if not sub_agent and existing.workspace_id != workspace_id:
                # Different workspace: allowed when the app is shared into
                # this workspace, or when the conversation is for this app.
                from app.models import AppShare
                from sqlalchemy import select as sa_select

                share_stmt = sa_select(AppShare).where(
                    AppShare.source_app_id == app_id,
                    AppShare.target_workspace_id == workspace_id
                )
                share = self.db.scalars(share_stmt).first()
                same_app = (existing.app_id == app_id)

                if not (share or same_app):
                    logger.warning(
                        "会话不属于当前工作空间",
                        extra={
                            "conversation_id": conversation_id,
                            "conversation_workspace_id": str(existing.workspace_id),
                            "current_workspace_id": str(workspace_id)
                        }
                    )
                    raise BusinessException(
                        "会话不属于当前工作空间",
                        BizCode.PERMISSION_DENIED
                    )

            logger.debug(
                "使用现有会话",
                extra={
                    "conversation_id": conversation_id,
                    "workspace_id": str(workspace_id)
                }
            )
            return conversation_id
        except BusinessException:
            # Permission errors pass through untouched.
            raise
        except Exception as e:
            logger.error(
                "会话不存在或无效",
                extra={"conversation_id": conversation_id, "error": str(e)}
            )
            raise BusinessException(
                f"会话不存在: {conversation_id}",
                BizCode.NOT_FOUND,
                cause=e
            )

    # No id supplied: create a fresh draft conversation.
    logger.info(
        "创建新的草稿会话",
        extra={"workspace_id": str(workspace_id)}
    )
    config_snapshot = await self._get_config_snapshot(app_id)

    new_uuid = uuid.uuid4()
    new_conv_id = str(new_uuid)
    draft = ConversationModel(
        id=new_uuid,
        app_id=app_id,
        workspace_id=workspace_id,
        user_id=user_id,
        is_draft=True,
        title="草稿会话",
        config_snapshot=config_snapshot
    )
    self.db.add(draft)
    self.db.commit()
    self.db.refresh(draft)

    if opening_statement:
        # Persist the opening statement as the first assistant message.
        service.add_message(
            conversation_id=new_uuid,
            role="assistant",
            content=opening_statement,
            meta_data={"suggested_questions": suggested_questions}
        )
        logger.debug(f"已保存开场白到会话 {new_conv_id}")

    logger.info(
        "创建草稿会话成功",
        extra={
            "conversation_id": new_conv_id,
            "workspace_id": str(workspace_id)
        }
    )
    return new_conv_id
async def _load_conversation_history(
    self,
    conversation_id: str,
    max_history: int = 10,
    current_provider: Optional[str] = None,
    current_is_omni: Optional[bool] = None
) -> List[Dict[str, str]]:
    """Load prior messages of a conversation for the current model.

    Args:
        conversation_id: Conversation id (string form of a UUID).
        max_history: Maximum number of history messages to load.
        current_provider: Provider of the model about to be called; passed
            through to ``get_conversation_history``.
        current_is_omni: ``is_omni`` flag of that model; passed through.

    Returns:
        List[Dict]: The history, or an empty list when loading fails
        (best-effort: failures are logged, never raised).
    """
    try:
        # Imported locally like the sibling helpers do, so this method does
        # not depend on a module-level name; a NameError here would otherwise
        # be swallowed below and silently turn into "no history".
        from app.services.conversation_service import ConversationService

        conversation_service = ConversationService(self.db)
        history = await conversation_service.get_conversation_history(
            conversation_id=uuid.UUID(conversation_id),
            max_history=max_history,
            current_provider=current_provider,
            current_is_omni=current_is_omni
        )
        logger.debug(
            "加载会话历史",
            extra={
                "conversation_id": conversation_id,
                "max_history": max_history,
                "loaded_count": len(history)
            }
        )
        return history
    except Exception as e:
        # A new conversation having no history is expected, so stay at debug
        # level, but keep the traceback so real failures can be diagnosed.
        logger.debug(
            "加载会话历史失败(可能是新会话)",
            extra={"error": str(e)},
            exc_info=True
        )
        return []
async def _save_conversation_message(
    self,
    conversation_id: str,
    user_message: str,
    assistant_message: str,
    meta_data: dict,
    app_id: Optional[uuid.UUID] = None,
    user_id: Optional[str] = None,
    files: Optional[List[FileInput]] = None,
    processed_files: Optional[List[Dict[str, Any]]] = None,
    audio_url: Optional[str] = None,
    citations: Optional[List[Any]] = None,
    provider: Optional[str] = None,
    is_omni: Optional[bool] = None
) -> None:
    """Persist one user/assistant message pair (best-effort).

    Any failure is logged as a warning and not raised.

    Args:
        conversation_id: Conversation id (string form of a UUID).
        user_message: Text of the user message.
        assistant_message: Text of the assistant reply.
        meta_data: Metadata stored with the assistant message. It is copied
            before ``audio_url`` / ``citations`` are added, so the caller's
            dict is left untouched.
        app_id: Unused; kept for compatibility.
        user_id: Unused; kept for compatibility.
        files: Original file inputs, summarized into the user message's
            ``files`` metadata.
        processed_files: Processed file content stored under
            ``history_files`` together with ``provider`` and ``is_omni``.
        audio_url: Added to the assistant metadata when truthy.
        citations: Added to the assistant metadata when truthy.
        provider: Model provider recorded with ``history_files``.
        is_omni: Omni-model flag recorded with ``history_files``.
    """
    try:
        from app.services.conversation_service import ConversationService

        conversation_service = ConversationService(self.db)
        conv_uuid = uuid.UUID(conversation_id)

        human_meta = {
            "files": [],
            "history_files": {}
        }

        if files:
            from app.models.file_metadata_model import FileMetadata

            def _needs_lookup(f):
                # Local uploads missing a name or size are completed from
                # their stored FileMetadata row.
                return (
                    f.transfer_method.value == "local_file"
                    and f.upload_file_id
                    and (not f.name or not f.size)
                )

            local_ids = [f.upload_file_id for f in files if _needs_lookup(f)]
            meta_map = {}
            if local_ids:
                rows = self.db.query(FileMetadata).filter(
                    FileMetadata.id.in_(local_ids),
                    FileMetadata.status == "completed"
                ).all()
                meta_map = {str(r.id): r for r in rows}

            for f in files:
                name, size = f.name, f.size
                if _needs_lookup(f):
                    meta = meta_map.get(str(f.upload_file_id))
                    if meta:
                        name = name or meta.file_name
                        size = size or meta.file_size
                human_meta["files"].append({
                    "type": f.type,
                    "url": f.url,
                    "file_type": f.file_type,
                    "name": name,
                    "size": size
                })

        # history_files also records which provider / is_omni produced it.
        if processed_files:
            human_meta["history_files"] = {
                "content": processed_files,
                "provider": provider,
                "is_omni": is_omni
            }

        conversation_service.add_message(
            conversation_id=conv_uuid,
            role="user",
            content=user_message,
            meta_data=human_meta
        )

        # Work on a copy so the caller's dict is not mutated as a side effect.
        assistant_meta = dict(meta_data or {})
        if audio_url:
            assistant_meta["audio_url"] = audio_url
        if citations:
            assistant_meta["citations"] = citations
        conversation_service.add_message(
            conversation_id=conv_uuid,
            role="assistant",
            content=assistant_message,
            meta_data=assistant_meta
        )

        logger.debug(
            "保存会话消息",
            extra={
                "conversation_id": conversation_id,
                "user_message_length": len(user_message),
                "assistant_message_length": len(assistant_message)
            }
        )
    except Exception as e:
        # Best-effort: never fail the run because persistence failed, but
        # keep the traceback for diagnosis.
        logger.warning("保存会话消息失败", extra={"error": str(e)}, exc_info=True)
async def _get_config_snapshot(self, app_id: uuid.UUID) -> Dict[str, Any]:
    """Build a snapshot of the app's current agent and model configuration.

    Args:
        app_id: Application id.

    Returns:
        Dict: ``agent_config``, ``model_config`` (or None) and
        ``snapshot_time``; an empty dict when the app has no ``AgentConfig``
        row or when building the snapshot fails.
    """
    try:
        from app.models import AgentConfig, ModelConfig

        stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
        agent_cfg = self.db.scalars(stmt).first()
        if not agent_cfg:
            return {}

        model_config = None
        if agent_cfg.default_model_config_id:
            model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)

        def safe_serialize(value):
            """Convert a config value into plain data where possible."""
            if value is None or isinstance(value, (str, int, float, bool, dict, list)):
                return value
            # Pydantic v2 deprecates ``.dict()``; prefer ``model_dump()`` and
            # keep ``.dict()`` only as a fallback for other objects.
            if hasattr(value, "model_dump"):
                return value.model_dump()
            if hasattr(value, "dict"):
                return value.dict()
            if hasattr(value, "__dict__"):
                # Copy so the snapshot does not alias the live object's state.
                return dict(value.__dict__)
            return str(value)

        return {
            "agent_config": {
                "system_prompt": agent_cfg.system_prompt,
                "model_parameters": safe_serialize(agent_cfg.model_parameters),
                "knowledge_retrieval": safe_serialize(agent_cfg.knowledge_retrieval),
                "memory": safe_serialize(agent_cfg.memory),
                "variables": safe_serialize(agent_cfg.variables),
                "tools": safe_serialize(agent_cfg.tools)
            },
            "model_config": {
                "model_name": model_config.name,
                "provider": model_config.provider,
                "type": model_config.type
            } if model_config else None,
            "snapshot_time": datetime.datetime.now().isoformat()
        }
    except Exception as e:
        # Per the original note, multi-agent apps have no direct AgentConfig,
        # so a failure here is treated as "no snapshot".
        logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
        return {}
async def _generate_suggested_questions(
self,
features_config: Dict[str, Any],
assistant_message: str,
api_key_config: Dict[str, Any],
effective_params: Dict[str, Any]
) -> List[str]:
"""根据 suggested_questions_after_answer 配置生成下一步建议问题"""
sq_config = features_config.get("suggested_questions_after_answer", {})
if not isinstance(sq_config, dict) or not sq_config.get("enabled"):
return []
try:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
llm = ChatOpenAI(
model=api_key_config["model_name"],
api_key=api_key_config["api_key"],
base_url=api_key_config.get("api_base"),
temperature=0.5,
max_tokens=200,
)
prompt = (
f"根据以下AI回复生成3个用户可能继续追问的简短问题每行一个不加序号\n\n{assistant_message}"
)
resp = await llm.ainvoke([HumanMessage(content=prompt)])
lines = [l.strip() for l in resp.content.strip().split("\n") if l.strip()]
return lines[:3]
except Exception as e:
logger.warning(f"生成建议问题失败: {e}")
return []
async def _generate_tts(
self,
features_config: Dict[str, Any],
text: str,
api_key_config: Dict[str, Any],
tenant_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
) -> Optional[str]:
"""先注册文件元数据并返回 audio_url再后台流式写入音频内容"""
tts_config = features_config.get("text_to_speech", {})
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
return None
if not text or not text.strip():
return None
from app.models.file_metadata_model import FileMetadata
from app.services.file_storage_service import FileStorageService, generate_file_key
provider = api_key_config.get("provider", "openai")
api_key = api_key_config.get("api_key")
api_base = api_key_config.get("api_base")
voice = tts_config.get("voice")
file_ext, content_type = ".mp3", "audio/mpeg"
file_id = uuid.uuid4()
file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)
# 先写入 pending 状态的元数据,立即返回 URL
db_file = FileMetadata(
id=file_id,
tenant_id=tenant_id,
workspace_id=workspace_id,
file_key=file_key,
file_name=f"tts_{file_id}{file_ext}",
file_ext=file_ext,
file_size=0,
content_type=content_type,
status="pending",
)
self.db.add(db_file)
self.db.commit()
server_url = settings.FILE_LOCAL_SERVER_URL
audio_url = f"{server_url}/storage/permanent/{file_id}"
# 后台任务:流式生成并写入存储,完成后更新状态
async def _stream_to_storage():
try:
storage_service = FileStorageService()
if provider == "dashscope":
stream = self._tts_dashscope_stream(
api_key=api_key,
text=text,
voice=voice or "longxiaochun",
tts_config=tts_config,
)
else:
stream = self._tts_openai_stream(
api_key=api_key,
api_base=api_base,
text=text,
voice=voice or "alloy",
)
total_size = await storage_service.upload_stream(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
stream=stream,
content_type=content_type,
)
# 更新元数据状态
with get_db_context() as bg_db:
record = bg_db.get(FileMetadata, file_id)
if record:
record.status = "completed"
record.file_size = total_size
bg_db.commit()
logger.debug(f"TTS 流式写入完成provider={provider}, file_key={file_key}")
except Exception as e:
logger.warning(f"TTS 流式写入失败: {e}")
with get_db_context() as bg_db:
record = bg_db.get(FileMetadata, file_id)
if record:
record.status = "failed"
bg_db.commit()
asyncio.create_task(_stream_to_storage())
return audio_url
async def _generate_tts_streaming(
self,
features_config: Dict[str, Any],
api_key_config: Dict[str, Any],
text_queue: asyncio.Queue,
tenant_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
) -> tuple[Optional[str], Optional[asyncio.Task]]:
"""文本流式输入并行合成音频。
返回 (audio_url, task)audio_url 立即可用pending状态task 完成后文件内容就绪。
调用方向 text_queue put 文本 chunk结束时 put None。
前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。
"""
tts_config = features_config.get("text_to_speech", {})
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
return None, None
from app.models.file_metadata_model import FileMetadata
from app.services.file_storage_service import FileStorageService, generate_file_key
provider = api_key_config.get("provider", "openai")
api_key = api_key_config.get("api_key")
api_base = api_key_config.get("api_base")
voice = tts_config.get("voice")
file_ext, content_type = ".mp3", "audio/mpeg"
file_id = uuid.uuid4()
file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)
db_file = FileMetadata(
id=file_id,
tenant_id=tenant_id,
workspace_id=workspace_id,
file_key=file_key,
file_name=f"tts_{file_id}{file_ext}",
file_ext=file_ext,
file_size=0,
content_type=content_type,
status="pending",
)
self.db.add(db_file)
self.db.commit()
server_url = settings.FILE_LOCAL_SERVER_URL
audio_url = f"{server_url}/storage/permanent/{file_id}"
async def _run():
try:
storage_service = FileStorageService()
if provider == "dashscope":
audio_stream = self._tts_dashscope_stream_from_queue(
api_key=api_key,
voice=voice or "longxiaochun",
tts_config=tts_config,
text_queue=text_queue,
)
else:
audio_stream = self._tts_openai_stream_from_queue(
api_key=api_key,
api_base=api_base,
voice=voice or "alloy",
text_queue=text_queue,
)
total_size = await storage_service.upload_stream(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
stream=audio_stream,
content_type=content_type,
)
with get_db_context() as bg_db:
record = bg_db.get(FileMetadata, file_id)
if record:
record.status = "completed"
record.file_size = total_size
bg_db.commit()
logger.debug(f"TTS 流式合成完成provider={provider}, file_key={file_key}")
except Exception as e:
logger.warning(f"TTS 流式合成失败: {e}")
with get_db_context() as bg_db:
record = bg_db.get(FileMetadata, file_id)
if record:
record.status = "failed"
bg_db.commit()
task = asyncio.create_task(_run())
return audio_url, task
@staticmethod
async def _tts_openai_stream_from_queue(
    api_key: str,
    api_base: Optional[str],
    voice: str,
    text_queue: asyncio.Queue,
):
    """Collect all queued text, then stream OpenAI-style TTS audio bytes.

    The whole text is gathered first (the original author's note says the
    OpenAI API takes no incremental input); only its first 4096 characters
    are sent. Nothing is yielded when the collected text is blank.

    Args:
        api_key: API key for the client.
        api_base: Optional base URL for the client.
        voice: Voice name.
        text_queue: Queue of text chunks terminated by ``None``.

    Yields:
        bytes: Audio chunks of up to 4096 bytes.
    """
    from openai import AsyncOpenAI

    # Drain the queue until the None terminator arrives.
    collected = []
    while (piece := await text_queue.get()) is not None:
        collected.append(piece)

    full_text = "".join(collected)
    if not full_text.strip():
        return

    speech_api = AsyncOpenAI(api_key=api_key, base_url=api_base).audio.speech
    request = speech_api.with_streaming_response.create(
        model="tts-1",
        voice=voice,
        input=full_text[:4096],
    )
    async with request as response:
        async for audio_bytes in response.iter_bytes(chunk_size=4096):
            yield audio_bytes
@staticmethod
async def _tts_dashscope_stream_from_queue(
    api_key: str,
    voice: str,
    tts_config: Dict[str, Any],
    text_queue: asyncio.Queue,
):
    """Stream DashScope TTS audio while text is still arriving on a queue.

    Text is split at sentence-ending punctuation and fed to the synthesizer
    sentence by sentence; audio delivered by the SDK callbacks is yielded as
    it arrives.

    Args:
        api_key: DashScope API key. NOTE(review): it is assigned to the
            module-global ``dashscope.api_key``, so concurrent calls with
            different keys would share it.
        voice: Voice name; a ``_v2`` suffix is added or removed to match the
            model name.
        tts_config: May provide ``model`` (default ``"cosyvoice-v2"``).
        text_queue: Queue of text chunks terminated by ``None``.

    Yields:
        bytes: Audio data chunks.

    Raises:
        RuntimeError: When the synthesizer reports an error.
        Exception: Any error raised while feeding text to the synthesizer.
    """
    import re

    import dashscope
    from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback

    model = tts_config.get("model") or "cosyvoice-v2"
    is_v2 = model.endswith("-v2")
    if is_v2 and not voice.endswith("_v2"):
        voice = voice + "_v2"
    elif not is_v2 and voice.endswith("_v2"):
        voice = voice[:-3]

    audio_queue: asyncio.Queue = asyncio.Queue()
    # We are inside a coroutine, so the running loop is the right one to
    # hand results back to.
    loop = asyncio.get_running_loop()

    class _Callback(ResultCallback):
        # The callbacks hand results over with call_soon_threadsafe, i.e.
        # they are treated as possibly running outside the event loop thread.
        def on_data(self, data: bytes):
            if data:
                loop.call_soon_threadsafe(audio_queue.put_nowait, data)

        def on_complete(self):
            loop.call_soon_threadsafe(audio_queue.put_nowait, None)

        def on_error(self, message):
            loop.call_soon_threadsafe(audio_queue.put_nowait, RuntimeError(str(message)))

        def on_open(self):
            pass

        def on_close(self):
            pass

    dashscope.api_key = api_key
    synthesizer = SpeechSynthesizer(
        model=model,
        voice=voice,
        format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
        callback=_Callback(),
    )

    sentence_end = re.compile(r'[\u3002\uff01\uff1f.!?\n]')

    async def _feed_text():
        """Read text_queue and feed the synthesizer one sentence at a time."""
        buf = ""
        try:
            while True:
                chunk = await text_queue.get()
                if chunk is None:
                    if buf.strip():
                        await asyncio.to_thread(synthesizer.streaming_call, buf)
                    await asyncio.to_thread(synthesizer.streaming_complete)
                    break
                buf += chunk
                while (m := sentence_end.search(buf)):
                    sentence, buf = buf[:m.end()], buf[m.end():]
                    await asyncio.to_thread(synthesizer.streaming_call, sentence)
        except Exception as exc:
            # Without this the consumer below would wait forever, because no
            # completion marker would ever reach audio_queue.
            audio_queue.put_nowait(exc)

    # Keep a reference so the feeder task is not garbage-collected.
    feeder = asyncio.create_task(_feed_text())
    try:
        while True:
            item = await audio_queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise item
            yield item
    finally:
        # Stop feeding if the consumer stops early (error or generator close).
        if not feeder.done():
            feeder.cancel()
@staticmethod
async def _tts_openai_stream(
    api_key: str,
    api_base: Optional[str],
    text: str,
    voice: str,
):
    """Stream OpenAI-compatible TTS audio for a complete text.

    Only the first 4096 characters of ``text`` are sent.

    Args:
        api_key: API key for the client.
        api_base: Optional base URL for the client.
        text: Text to synthesize.
        voice: Voice name.

    Yields:
        bytes: Audio chunks of up to 4096 bytes.
    """
    from openai import AsyncOpenAI

    speech_api = AsyncOpenAI(api_key=api_key, base_url=api_base).audio.speech
    request = speech_api.with_streaming_response.create(
        model="tts-1",
        voice=voice,
        input=text[:4096],
    )
    async with request as response:
        async for audio_bytes in response.iter_bytes(chunk_size=4096):
            yield audio_bytes
@staticmethod
async def _tts_dashscope_stream(
    api_key: str,
    text: str,
    voice: str,
    tts_config: Dict[str, Any],
):
    """Stream DashScope TTS audio for a complete text.

    The blocking SDK calls run in a worker thread; audio delivered by the
    SDK callbacks is handed to the event loop and yielded here. Only the
    first 4096 characters of ``text`` are sent.

    Args:
        api_key: DashScope API key. NOTE(review): it is assigned to the
            module-global ``dashscope.api_key``, so concurrent calls with
            different keys would share it.
        text: Text to synthesize.
        voice: Voice name; a ``_v2`` suffix is added or removed to match the
            model name.
        tts_config: May provide ``model`` (default ``"cosyvoice-v2"``).

    Yields:
        bytes: Audio data chunks.

    Raises:
        RuntimeError: When the synthesizer reports an error.
        Exception: Any error raised by the SDK calls in the worker thread.
    """
    import dashscope
    from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback

    model = tts_config.get("model") or "cosyvoice-v2"
    is_v2 = model.endswith("-v2")
    if is_v2 and not voice.endswith("_v2"):
        voice = voice + "_v2"
    elif not is_v2 and voice.endswith("_v2"):
        voice = voice[:-3]

    queue: asyncio.Queue = asyncio.Queue()
    # We are inside a coroutine, so the running loop is the right one to
    # hand results back to.
    loop = asyncio.get_running_loop()

    class _Callback(ResultCallback):
        def on_data(self, data: bytes):
            if data:
                loop.call_soon_threadsafe(queue.put_nowait, data)

        def on_complete(self):
            loop.call_soon_threadsafe(queue.put_nowait, None)

        def on_error(self, message):
            loop.call_soon_threadsafe(queue.put_nowait, RuntimeError(str(message)))

        def on_open(self):
            pass

        def on_close(self):
            pass

    def _sync_stream():
        try:
            dashscope.api_key = api_key
            synthesizer = SpeechSynthesizer(
                model=model,
                voice=voice,
                format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
                callback=_Callback(),
            )
            synthesizer.streaming_call(text[:4096])
            synthesizer.streaming_complete()
        except Exception as exc:
            # Forward the failure; otherwise the consumer below would wait
            # forever for a completion marker that never arrives.
            loop.call_soon_threadsafe(queue.put_nowait, exc)

    # Hold the task in a local so it stays referenced while we consume.
    worker = asyncio.create_task(asyncio.to_thread(_sync_stream))  # noqa: F841

    while True:
        item = await queue.get()
        if item is None:
            break
        if isinstance(item, Exception):
            raise item
        yield item
def _replace_variables(
self,
text: str,
values: Dict[str, Any],
definitions: List[Dict[str, Any]]
) -> str:
"""替换文本中的变量
Args:
text: 原始文本
values: 变量值
definitions: 变量定义
Returns:
str: 替换后的文本
"""
result = text
# 创建变量定义映射
var_defs = {var["name"]: var for var in definitions}
for var_name, var_value in values.items():
# 检查变量是否在定义中
if var_name not in var_defs:
logger.warning(f"未定义的变量: {var_name}")
continue
# 替换变量(支持多种格式)
placeholders = [
f"{{{{{var_name}}}}}", # {{var_name}}
f"{{{var_name}}}", # {var_name}
f"${{{var_name}}}", # ${var_name}
]
for placeholder in placeholders:
if placeholder in result:
result = result.replace(placeholder, str(var_value))
return result
# ==================== 多模型对比试运行 ====================
async def run_compare(
    self,
    *,
    agent_config: AgentConfig,
    models: List[Dict[str, Any]],
    message: str,
    workspace_id: uuid.UUID,
    conversation_id: Optional[str] = None,
    user_id: Optional[str] = None,
    variables: Optional[Dict[str, Any]] = None,
    parallel: bool = True,
    timeout: int = 60,
    storage_type: Optional[str] = None,
    user_rag_memory_id: Optional[str] = None,
    web_search: bool = True,
    memory: bool = True,
    files: list[FileInput] | None = None
) -> Dict[str, Any]:
    """Run the same message against several models and compare the results.

    Args:
        agent_config: Agent configuration; its ``model_parameters`` are
            temporarily replaced for each model and restored afterwards.
        models: One dict per model with ``model_config``, ``parameters``,
            ``label``, ``model_config_id`` and optionally ``conversation_id``.
        message: User message.
        workspace_id: Workspace id.
        conversation_id: Fallback conversation id for models without their own.
        user_id: User id.
        variables: Variable values.
        parallel: Run all models concurrently when True, else one by one.
        timeout: Per-model time limit in seconds.
        storage_type: Passed through to ``run``.
        user_rag_memory_id: Passed through to ``run``.
        web_search: Passed through to ``run``.
        memory: Passed through to ``run``.
        files: Passed through to ``run``.

    Returns:
        Dict: ``results`` (one entry per model), ``total_elapsed_time`` (sum
        of the per-model times), ``successful_count``, ``failed_count``,
        ``fastest_model`` and ``cheapest_model`` (None when no successful
        result carries a cost estimate).
    """
    logger.info(
        "多模型对比试运行",
        extra={
            "model_count": len(models),
            "parallel": parallel
        }
    )

    # Captured once: every run restores to this value. Saving/restoring per
    # task would, in parallel mode, let a task "restore" another model's
    # parameters and leave them on the shared config.
    # NOTE(review): parallel runs still share one agent_config object, so a
    # run may observe another model's parameters while it is in flight;
    # fixing that needs a per-run copy, which the original avoided because of
    # SQLAlchemy session issues.
    baseline_params = agent_config.model_parameters

    async def run_single_model(model_info):
        """Run one model and return its result entry (never raises)."""
        # Prefer the model's own conversation id, else the shared one.
        model_conversation_id = model_info.get("conversation_id") or conversation_id
        try:
            start_time = time.time()
            agent_config.model_parameters = model_info["parameters"]
            try:
                result = await asyncio.wait_for(
                    self.run(
                        agent_config=agent_config,
                        model_config=model_info["model_config"],
                        message=message,
                        workspace_id=workspace_id,
                        conversation_id=model_conversation_id,
                        user_id=user_id,
                        variables=variables,
                        storage_type=storage_type,
                        user_rag_memory_id=user_rag_memory_id,
                        web_search=web_search,
                        memory=memory,
                        files=files
                    ),
                    timeout=timeout
                )
            finally:
                agent_config.model_parameters = baseline_params

            elapsed = time.time() - start_time
            usage = result.get("usage", {})

            return {
                "model_config_id": model_info["model_config_id"],
                "model_name": model_info["model_config"].name,
                "label": model_info["label"],
                "conversation_id": result['conversation_id'],
                "parameters_used": model_info["parameters"],
                "message": result.get("message"),
                "reasoning_content": result.get("reasoning_content"),
                "usage": usage,
                "elapsed_time": elapsed,
                "tokens_per_second": (
                    usage.get("completion_tokens", 0) / elapsed
                    if elapsed > 0 and usage.get("completion_tokens") else None
                ),
                "cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
                "audio_url": result.get("audio_url"),
                "audio_status": result.get("audio_status"),
                "citations": result.get("citations", []),
                "suggested_questions": result.get("suggested_questions", []),
                "error": None
            }
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is the builtin TimeoutError only from
            # Python 3.11; naming it explicitly works on every version.
            logger.warning(
                "模型运行超时",
                extra={
                    "model_config_id": str(model_info["model_config_id"]),
                    "timeout": timeout
                }
            )
            return {
                "model_config_id": model_info["model_config_id"],
                "model_name": model_info["model_config"].name,
                "conversation_id": model_conversation_id,
                "label": model_info["label"],
                "parameters_used": model_info["parameters"],
                "elapsed_time": timeout,
                "error": f"执行超时({timeout}秒)"
            }
        except Exception as e:
            logger.error(
                "模型运行失败",
                extra={
                    "model_config_id": str(model_info["model_config_id"]),
                    "error": str(e)
                },
                exc_info=True
            )
            return {
                "model_config_id": model_info["model_config_id"],
                "model_name": model_info["model_config"].name,
                "label": model_info["label"],
                "conversation_id": model_conversation_id,
                "parameters_used": model_info["parameters"],
                "elapsed_time": 0,
                # An empty message would make the failure look like a success
                # in the statistics below.
                "error": str(e) or type(e).__name__
            }

    if parallel:
        logger.debug(f"并行执行 {len(models)} 个模型")
        results = await asyncio.gather(
            *[run_single_model(m) for m in models],
            return_exceptions=False
        )
    else:
        logger.debug(f"串行执行 {len(models)} 个模型")
        results = []
        for model_info in models:
            results.append(await run_single_model(model_info))

    successful = [r for r in results if not r.get("error")]
    failed = [r for r in results if r.get("error")]

    fastest = min(successful, key=lambda r: r["elapsed_time"]) if successful else None
    # Only rank by cost when a cost is actually known; otherwise the first
    # successful model would be reported as "cheapest" with no basis.
    priced = [r for r in successful if r.get("cost_estimate") is not None]
    cheapest = min(priced, key=lambda r: r["cost_estimate"]) if priced else None

    logger.info(
        "多模型对比完成",
        extra={
            "successful": len(successful),
            "failed": len(failed),
            "total_time": sum(r.get("elapsed_time", 0) for r in results)
        }
    )

    return {
        # Error entries lack the feature fields; fill them with defaults so
        # every result has the same shape.
        "results": [{
            **r,
            "audio_url": r.get("audio_url"),
            "audio_status": r.get("audio_status"),
            "citations": r.get("citations", []),
            "suggested_questions": r.get("suggested_questions", []),
        } for r in results],
        "total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results),
        "successful_count": len(successful),
        "failed_count": len(failed),
        "fastest_model": fastest["label"] if fastest else None,
        "cheapest_model": cheapest["label"] if cheapest else None
    }
def _estimate_cost(self, usage: Dict[str, Any], model_config) -> Optional[float]:
"""估算成本
Args:
usage: Token 使用情况
model_config: 模型配置
Returns:
Optional[float]: 估算成本(美元)
"""
if not usage:
return None
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
# 简化成本估算:暂时返回 None
# TODO: 实现基于模型名称或配置的成本估算
# 需要从 ModelApiKey 获取实际的模型名称,或者在 ModelConfig 中添加 model 字段
return None
def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]) -> tuple[AgentConfig, Any]:
    """Override ``model_parameters`` on ``agent_config`` in place.

    No copy is made: the object passed in is the one modified and returned.

    Args:
        agent_config: Agent configuration to modify.
        parameters: New value for ``model_parameters``.

    Returns:
        tuple: ``(agent_config, previous_parameters)`` — the same config
        object and the value ``model_parameters`` held before the override,
        so the caller can restore it.
    """
    previous_parameters = agent_config.model_parameters
    agent_config.model_parameters = parameters
    return agent_config, previous_parameters
async def run_compare_stream(
    self,
    *,
    agent_config: AgentConfig,
    models: List[Dict[str, Any]],
    message: str,
    workspace_id: uuid.UUID,
    conversation_id: Optional[str] = None,
    user_id: Optional[str] = None,
    variables: Optional[Dict[str, Any]] = None,
    storage_type: Optional[str] = None,
    user_rag_memory_id: Optional[str] = None,
    web_search: bool = True,
    memory: bool = True,
    parallel: bool = True,
    timeout: int = 60,
    files: list[FileInput] | None = None
) -> AsyncGenerator[str, None]:
    """Run the same message against several models, streaming SSE events.

    Emits ``compare_start``, then per model ``model_start``,
    ``model_message`` / ``model_reasoning`` chunks, ``model_error`` on
    failure and ``model_end``, and finally ``compare_end`` with a summary.
    Events are forwarded as they are produced in both parallel and serial
    mode. If the consumer closes the stream early, unfinished model runs
    are cancelled.

    Args:
        agent_config: Agent configuration; its ``model_parameters`` are
            temporarily replaced for each model and restored afterwards.
        models: One dict per model with ``model_config``, ``parameters``,
            ``label``, ``model_config_id`` and optionally ``conversation_id``.
        message: User message.
        workspace_id: Workspace id.
        conversation_id: Fallback conversation id for models without their own.
        user_id: User id.
        variables: Variable values.
        storage_type: Passed through to ``run_stream``.
        user_rag_memory_id: Passed through to ``run_stream``.
        web_search: Passed through to ``run_stream``.
        memory: Passed through to ``run_stream``.
        parallel: Run all models concurrently when True, else one by one.
        timeout: Per-model time limit in seconds for the whole stream;
            ``None`` disables the limit.
        files: Passed through to ``run_stream``.

    Yields:
        str: SSE-formatted event frames.
    """
    logger.info(
        "多模型对比流式试运行",
        extra={"model_count": len(models), "parallel": parallel}
    )

    yield self._format_sse_event("compare_start", {
        "conversation_id": conversation_id,
        "model_count": len(models),
        "parallel": parallel,
        "timestamp": time.time()
    })

    results = []
    # Captured once so every run restores the same value.
    # NOTE(review): parallel runs still share one agent_config object, so a
    # run may observe another model's parameters while it is in flight.
    baseline_params = agent_config.model_parameters
    # Queue marker meaning "one model run has finished".
    finished = object()

    def parse_sse(event_str):
        """Split one SSE frame into (event type, decoded JSON data)."""
        event_type = None
        event_data = None
        for line in event_str.strip().split('\n'):
            if line.startswith('event: '):
                event_type = line[7:].strip()
            elif line.startswith('data: '):
                event_data = json.loads(line[6:])
        return event_type, event_data

    async def run_single_model_stream(idx: int, model_info: Dict[str, Any], event_queue: asyncio.Queue):
        """Stream one model, queue its tagged events, return its result entry."""
        model_label = model_info["label"]
        model_config_id = str(model_info["model_config_id"])
        # Prefer the model's own conversation id, else the shared one.
        model_conversation_id = model_info.get("conversation_id") or conversation_id

        def tagged(event, conv_id, **fields):
            """Format an event carrying this model's identifying fields."""
            return self._format_sse_event(event, {
                "model_index": idx,
                "model_config_id": model_config_id,
                "label": model_label,
                "conversation_id": conv_id,
                **fields
            })

        try:
            await event_queue.put(self._format_sse_event("model_start", {
                "model_index": idx,
                "model_config_id": model_config_id,
                "model_name": model_info["model_config"].name,
                "label": model_label,
                "conversation_id": model_conversation_id,
                "timestamp": time.time()
            }))

            start_time = time.time()
            state = {
                "content": "",
                "reasoning": "",
                "conversation_id": model_conversation_id,
                "audio_url": None,
                "audio_status": None,
                "citations": [],
                "suggested_questions": [],
                "error": None,
            }

            async def consume():
                """Relay the model's own event stream into event_queue."""
                async for event_str in self.run_stream(
                    agent_config=agent_config,
                    model_config=model_info["model_config"],
                    message=message,
                    workspace_id=workspace_id,
                    conversation_id=model_conversation_id,
                    user_id=user_id,
                    variables=variables,
                    storage_type=storage_type,
                    user_rag_memory_id=user_rag_memory_id,
                    web_search=web_search,
                    memory=memory,
                    files=files
                ):
                    try:
                        event_type, event_data = parse_sse(event_str)
                        if not event_data:
                            continue

                        if event_type == "start":
                            # The start event carries the real conversation id.
                            conv_id = event_data.get("conversation_id")
                            if conv_id:
                                state["conversation_id"] = conv_id
                        elif event_type == "message":
                            chunk = event_data.get("content", "")
                            state["content"] += chunk
                            await event_queue.put(tagged(
                                "model_message", state["conversation_id"], content=chunk
                            ))
                        elif event_type == "reasoning":
                            reasoning_chunk = event_data.get("content", "")
                            state["reasoning"] += reasoning_chunk
                            await event_queue.put(tagged(
                                "model_reasoning", state["conversation_id"], content=reasoning_chunk
                            ))
                        elif event_type == "end":
                            state["audio_url"] = event_data.get("audio_url")
                            state["audio_status"] = event_data.get("audio_status")
                            state["citations"] = event_data.get("citations", [])
                            state["suggested_questions"] = event_data.get("suggested_questions", [])
                        elif event_type == "error":
                            reported = event_data.get("error", "未知错误")
                            # Remember the failure so this model is not
                            # counted as successful in the summary.
                            state["error"] = reported or "未知错误"
                            await event_queue.put(tagged(
                                "model_error", state["conversation_id"], error=reported
                            ))
                    except Exception as e:
                        logger.warning(f"解析流式事件失败: {e}")

            agent_config.model_parameters = model_info["parameters"]
            try:
                # Enforce the documented per-model timeout, as run_compare does.
                await asyncio.wait_for(consume(), timeout=timeout)
            finally:
                agent_config.model_parameters = baseline_params

            elapsed = time.time() - start_time

            result = {
                "model_config_id": model_info["model_config_id"],
                "model_name": model_info["model_config"].name,
                "label": model_label,
                "conversation_id": state["conversation_id"],
                "parameters_used": model_info["parameters"],
                "message": state["content"],
                "reasoning_content": state["reasoning"] or None,
                "elapsed_time": elapsed,
                "audio_url": state["audio_url"],
                "audio_status": state["audio_status"],
                "citations": state["citations"],
                "suggested_questions": state["suggested_questions"],
                "error": state["error"]
            }

            await event_queue.put(tagged(
                "model_end",
                state["conversation_id"],
                elapsed_time=elapsed,
                message_length=len(state["content"]),
                audio_url=state["audio_url"],
                audio_status=state["audio_status"],
                citations=state["citations"],
                suggested_questions=state["suggested_questions"],
                timestamp=time.time()
            ))
            return result

        except asyncio.TimeoutError:
            # asyncio.TimeoutError is the builtin TimeoutError only from
            # Python 3.11; naming it explicitly works on every version.
            logger.warning(f"模型运行超时: {model_label}")
            result = {
                "model_config_id": model_info["model_config_id"],
                "model_name": model_info["model_config"].name,
                "label": model_label,
                "conversation_id": model_conversation_id,
                "parameters_used": model_info["parameters"],
                "elapsed_time": timeout,
                "error": f"执行超时({timeout}秒)"
            }
            await event_queue.put(tagged(
                "model_error",
                model_conversation_id,
                error=result["error"],
                timestamp=time.time()
            ))
            return result

        except Exception as e:
            logger.error(f"模型运行失败: {model_label}, error: {e}")
            # An empty message would make the failure look like a success.
            error_text = str(e) or type(e).__name__
            result = {
                "model_config_id": model_info["model_config_id"],
                "model_name": model_info["model_config"].name,
                "label": model_label,
                "conversation_id": model_conversation_id,
                "parameters_used": model_info["parameters"],
                "elapsed_time": 0,
                "error": error_text
            }
            await event_queue.put(tagged(
                "model_error",
                model_conversation_id,
                error=error_text,
                timestamp=time.time()
            ))
            return result

    async def supervise(idx, model_info, event_queue):
        """Run one model, record its result, then signal completion."""
        try:
            outcome = await run_single_model_stream(idx, model_info, event_queue)
            if outcome:
                results.append(outcome)
        except Exception as e:
            logger.error(f"获取任务结果失败: {e}")
        finally:
            # Always queued last, so all of this model's events precede it.
            event_queue.put_nowait(finished)

    indexed_models = list(enumerate(models))
    if parallel:
        logger.debug(f"并行执行 {len(models)} 个模型(流式)")
        batches = [indexed_models]
    else:
        logger.debug(f"串行执行 {len(models)} 个模型(流式)")
        batches = [[pair] for pair in indexed_models]

    for batch in batches:
        event_queue = asyncio.Queue()
        tasks = [
            asyncio.create_task(supervise(idx, model_info, event_queue))
            for idx, model_info in batch
        ]
        pending = len(tasks)
        try:
            # Forward events as they arrive until every model in the batch
            # has queued its completion marker.
            while pending:
                item = await event_queue.get()
                if item is finished:
                    pending -= 1
                    continue
                yield item
        finally:
            # Reached with unfinished tasks only when the consumer closed the
            # stream or an error occurred: stop the remaining model runs.
            for task in tasks:
                if not task.done():
                    task.cancel()

    successful = [r for r in results if not r.get("error")]
    failed = [r for r in results if r.get("error")]

    fastest = min(successful, key=lambda r: r["elapsed_time"]) if successful else None
    # Only rank by cost when a cost is actually known.
    priced = [r for r in successful if r.get("cost_estimate") is not None]
    cheapest = min(priced, key=lambda r: r["cost_estimate"]) if priced else None

    # Summary entries include the full message of each model.
    results_summary = [
        {
            "model_config_id": str(r["model_config_id"]),
            "model_name": r["model_name"],
            "label": r["label"],
            "conversation_id": r.get("conversation_id"),
            "message": r.get("message"),
            "reasoning_content": r.get("reasoning_content"),
            "elapsed_time": r.get("elapsed_time", 0),
            "audio_url": r.get("audio_url"),
            "audio_status": r.get("audio_status"),
            "citations": r.get("citations", []),
            "suggested_questions": r.get("suggested_questions", []),
            "error": r.get("error")
        }
        for r in results
    ]

    yield self._format_sse_event("compare_end", {
        "conversation_id": conversation_id,
        "results": results_summary,
        "total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results),
        "successful_count": len(successful),
        "failed_count": len(failed),
        "fastest_model": fastest["label"] if fastest else None,
        "cheapest_model": cheapest["label"] if cheapest else None,
        "timestamp": time.time()
    })

    logger.info(
        "多模型对比流式完成",
        extra={
            "successful": len(successful),
            "failed": len(failed),
            "total_time": sum(r.get("elapsed_time", 0) for r in results)
        }
    )