- Replace plain image URLs with `<img src="..." data-url="...">` HTML tags in multimodal and document extractor services - Propagate citations from workflow end events to client responses - Update system prompts to instruct LLMs to render images using Markdown `![](url)` syntax with strict UUID-preserving URL copying
2499 lines
101 KiB
Python
2499 lines
101 KiB
Python
"""
|
||
试运行服务
|
||
|
||
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
|
||
"""
|
||
import asyncio
|
||
import datetime
|
||
import json
|
||
import time
|
||
import uuid
|
||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||
|
||
from langchain.agents import create_agent
|
||
from langchain.tools import tool
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy import select
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.core.agent.agent_middleware import AgentMiddleware
|
||
from app.core.agent.langchain_agent import LangChainAgent
|
||
from app.core.config import settings
|
||
from app.core.error_codes import BizCode
|
||
from app.core.exceptions import BusinessException
|
||
from app.core.logging_config import get_business_logger
|
||
from app.core.memory.enums import SearchStrategy
|
||
from app.core.memory.memory_service import MemoryService
|
||
from app.core.rag.nlp.search import knowledge_retrieval
|
||
from app.db import get_db_context
|
||
from app.models import AgentConfig, ModelConfig
|
||
from app.repositories.tool_repository import ToolRepository
|
||
from app.schemas.app_schema import FileInput, Citation, FileType
|
||
from app.schemas.model_schema import ModelInfo
|
||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||
from app.services.conversation_service import ConversationService
|
||
from app.services.langchain_tool_server import Search
|
||
from app.services.model_parameter_merger import ModelParameterMerger
|
||
from app.services.model_service import ModelApiKeyService
|
||
from app.services.multimodal_service import MultimodalService
|
||
from app.services.tool_service import ToolService
|
||
|
||
# Module-level business logger used by the tool factories and AgentRunService methods.
logger = get_business_logger()
|
||
|
||
|
||
class KnowledgeRetrievalInput(BaseModel):
    """知识库检索工具输入参数"""

    # Argument schema of knowledge_retrieval_tool: the search text, required (no default).
    query: str = Field(..., description="需要检索的问题或关键词")
|
||
|
||
|
||
class WebSearchInput(BaseModel):
    """网络搜索工具输入参数"""

    # Argument schema of web_search_tool: the search text, required (no default).
    query: str = Field(..., description="需要搜索的问题或关键词")
|
||
|
||
|
||
class LongTermMemoryInput(BaseModel):
    """长期记忆工具输入参数"""

    # Argument schema of long_term_memory: the model-rewritten retrieval query, required.
    question: str = Field(
        ...,
        description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写",
    )
|
||
|
||
|
||
def create_long_term_memory_tool(
    memory_config: Dict[str, Any],
    end_user_id: str,
    storage_type: Optional[str] = None,
    user_rag_memory_id: Optional[str] = None
):
    """Create the long-term memory retrieval tool for one end user.

    Args:
        memory_config: Memory configuration. The memory config id is read from
            ``memory_config_id`` and falls back to the legacy ``memory_content`` key.
        end_user_id: End user whose memories are searched.
        storage_type: Storage type (optional). Only logged here; it is not passed
            to ``MemoryService``.
        user_rag_memory_id: User RAG memory id (optional). Not used by the current
            lookup; accepted for interface compatibility.

    Returns:
        A LangChain tool named ``long_term_memory``.
    """
    # Accept both the new and the legacy field name for the memory config id.
    config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None)
    logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")

    @tool(args_schema=LongTermMemoryInput)
    def long_term_memory(question: str) -> str:
        """
        从用户的历史记忆中检索相关信息。用于了解用户的背景、偏好和历史对话内容。

        **何时使用此工具:**
        - 用户明确询问历史信息(如"我之前说过什么"、"上次我们聊了什么")
        - 用户询问个人信息或偏好(如"我喜欢什么"、"我的习惯是什么")
        - 需要基于历史上下文提供个性化建议

        **何时不使用此工具:**
        - 简单问候(如"你好"、"谢谢"、"再见")
        - 纯任务性请求(如"写代码"、"翻译文字"、"分析图片")
        - 用户已提供完整信息(如提供了文本、图片、文档等内容)
        - 创作性任务(如"写诗"、"编故事"、"创作谜语")

        **重要:如果用户的问题可以直接回答,不要调用此工具。只在确实需要历史信息时才使用。**

        Args:
            question: 需要检索的问题(保持原问题的核心语义,使用清晰的关键词,第三人称描述的偏好、行为通常指用户本人,比如(我,本人,在下,自己,咱,鄙人,吾,余)通指用户)

        Returns:
            检索到的历史记忆内容
        """
        logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
        try:
            with get_db_context() as db:
                memory_service = MemoryService(db, config_id, end_user_id)
                # The tool is synchronous, so the async read is driven to completion here.
                # NOTE(review): asyncio.run() raises if this thread already runs an event
                # loop; this assumes LangChain executes sync tools in a worker thread — confirm.
                search_result = asyncio.run(memory_service.read(question, SearchStrategy.QUICK))
            return f"检索到以下历史记忆:\n\n{search_result.content}"
        except Exception as e:
            # Deliberate best-effort: the error text is returned to the model instead of raised.
            logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
            return f"记忆检索失败: {str(e)}"

    return long_term_memory
|
||
|
||
|
||
def create_web_search_tool(web_search_config: Dict[str, Any]):
    """Build the web search tool exposed to the agent.

    Args:
        web_search_config: Web search configuration. The current implementation
            does not read it; it is accepted for interface compatibility.

    Returns:
        A LangChain tool named ``web_search_tool`` that wraps ``Search``.
    """
    logger.info("创建网络搜索工具")

    @tool(args_schema=WebSearchInput)
    def web_search_tool(query: str) -> str:
        """从互联网搜索最新信息。当用户的问题需要实时信息、最新新闻或网络资料时,使用此工具进行搜索。

        Args:
            query: 需要搜索的问题或关键词

        Returns:
            搜索到的相关网络信息
        """
        try:
            logger.info(f"执行网络搜索: {query}")
            found = Search(query)
            log_extra = {"query": query, "result_length": len(found)}
            logger.info("网络搜索成功", extra=log_extra)
            return f"搜索到以下网络信息:\n\n{found}"
        except Exception as exc:
            # Failures are reported back to the model as text rather than raised.
            logger.error("网络搜索失败", extra={"error": str(exc), "error_type": type(exc).__name__})
            return f"搜索失败: {str(exc)}"

    return web_search_tool
|
||
|
||
|
||
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None):
    """Build the knowledge-base retrieval tool exposed to the agent.

    Args:
        kb_config: Knowledge retrieval configuration, passed as-is to ``knowledge_retrieval``.
        kb_ids: Knowledge base ids (used for logging only).
        user_id: User id (used for logging only).
        citations_collector: Optional list the tool appends ``Citation`` entries to
            while running. Each carries ``document_id``, ``file_name``,
            ``knowledge_id`` and ``score``. Entries are de-duplicated by
            ``document_id`` across every call of the tool sharing this list.

    Returns:
        A LangChain tool named ``knowledge_retrieval_tool``.
    """
    logger.info(f"创建知识库检索工具,用户:{user_id}")

    def _doc_id_of(citation) -> Optional[str]:
        # Collector entries are Citation models (no ``.get``); plain dicts are tolerated too.
        if isinstance(citation, dict):
            return citation.get("document_id")
        return getattr(citation, "document_id", None)

    @tool(args_schema=KnowledgeRetrievalInput)
    def knowledge_retrieval_tool(query: str) -> str:
        """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。

        Args:
            query: 需要检索的问题或关键词

        Returns:
            检索到的相关知识内容
        """
        try:
            retrieve_chunks_result = knowledge_retrieval(query, kb_config)
            if not retrieve_chunks_result:
                logger.warning("知识库检索未找到结果")
                return "未找到相关信息"

            retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
            context = '\n\n'.join(retrieval_knowledge)
            logger.info(
                "知识库检索成功",
                extra={
                    "kb_ids": kb_ids,
                    "result_count": len(retrieval_knowledge),
                    "total_length": len(context)
                }
            )

            # Collect citation info; previous calls may already have filled the collector.
            if citations_collector is not None:
                seen_doc_ids = {_doc_id_of(c) for c in citations_collector}
                for chunk in retrieve_chunks_result:
                    meta = chunk.metadata or {}
                    doc_id = meta.get("document_id") or meta.get("doc_id")
                    if doc_id and doc_id not in seen_doc_ids:
                        seen_doc_ids.add(doc_id)
                        citations_collector.append(Citation(
                            document_id=doc_id,
                            file_name=meta.get("file_name", ""),
                            knowledge_id=str(meta.get("knowledge_id", "")),
                            score=meta.get("score", 0)
                        ))

            return f"检索到以下相关信息:\n\n{context}"
        except Exception as e:
            # Deliberate best-effort: the error text is returned to the model instead of raised.
            logger.error("知识库检索失败", extra={"error": str(e), "error_type": type(e).__name__})
            return f"检索失败: {str(e)}"

    return knowledge_retrieval_tool
|
||
|
||
|
||
class AgentRunService:
|
||
"""Agent运行服务类"""
|
||
|
||
def __init__(self, db: Session):
    """Create the service around an open database session.

    Args:
        db: SQLAlchemy session, stored as ``self.db`` for the tool, skill and
            repository lookups this service performs.
    """
    self.db = db
|
||
|
||
@staticmethod
|
||
def prepare_variables(
|
||
input_vars: dict | None,
|
||
variables_config: dict
|
||
) -> dict:
|
||
input_vars = input_vars or {}
|
||
for variable in variables_config:
|
||
if variable.get("required") and variable.get("name") not in input_vars:
|
||
raise ValueError(f"The required parameter '{variable.get('name')}' was not provided")
|
||
return input_vars
|
||
|
||
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
    """Build the LangChain tools configured for the agent.

    Args:
        tools_config: Tool configuration. Only a list of dicts is handled. Each entry
            with a truthy ``enabled`` is resolved by ``tool_id`` and converted with
            its optional ``operation``. Other shapes yield no configured tools.
        web_search: Whether to add the built-in web search tool first.
        tenant_id: Tenant used to resolve tool instances.

    Returns:
        The list of LangChain tools.
    """
    tools = []
    if web_search:
        tools.append(create_web_search_tool({}))
        logger.debug("已添加网络搜索工具", extra={"tool_count": len(tools)})

    if not tools_config or not isinstance(tools_config, list):
        return tools

    tool_service = ToolService(self.db)
    for tool_config in tools_config:
        if not tool_config.get("enabled", False):
            continue
        tool_id = tool_config.get("tool_id", "")
        tool_instance = tool_service.get_tool_instance(tool_id, tenant_id)
        if not tool_instance:
            # Make misconfigured / deleted tools visible instead of skipping silently.
            logger.warning("未找到工具实例,已跳过", extra={"tool_id": tool_id})
            continue
        # Convert to a LangChain tool (optionally bound to one operation).
        tools.append(tool_instance.to_langchain_tool(tool_config.get("operation", None)))
        logger.debug("已添加工具", extra={"tool_id": tool_id, "tool_count": len(tools)})

    return tools
|
||
|
||
def load_skill_config(
|
||
self,
|
||
skills_config: dict | None,
|
||
message: str, tenant_id
|
||
) -> tuple[list, str]:
|
||
if not skills_config:
|
||
return [], ""
|
||
|
||
tools = []
|
||
skill_prompts = ""
|
||
skill_enable = skills_config.get("enabled", False)
|
||
if skill_enable:
|
||
middleware = AgentMiddleware(skills=skills_config)
|
||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||
tools.extend(skill_tools)
|
||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||
|
||
if skill_configs:
|
||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||
tool_to_skill_map)
|
||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||
skill_prompts = AgentMiddleware.get_active_prompts(
|
||
activated_skill_ids, skill_configs
|
||
)
|
||
|
||
return tools, skill_prompts
|
||
|
||
def load_knowledge_retrieval_config(
|
||
self,
|
||
knowledge_retrieval_config: dict | None,
|
||
user_id
|
||
) -> tuple[list, list]:
|
||
"""返回 (tools, citations_collector)"""
|
||
if not knowledge_retrieval_config:
|
||
return [], []
|
||
|
||
citations_collector = []
|
||
tools = []
|
||
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
|
||
kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")]
|
||
if kb_ids:
|
||
kb_tool = create_knowledge_retrieval_tool(
|
||
knowledge_retrieval_config, kb_ids, user_id,
|
||
citations_collector=citations_collector
|
||
)
|
||
tools.append(kb_tool)
|
||
logger.debug(
|
||
"已添加知识库检索工具",
|
||
extra={"kb_ids": kb_ids, "tool_count": len(tools)}
|
||
)
|
||
return tools, citations_collector
|
||
|
||
def load_memory_config(
|
||
self,
|
||
memory_config: dict | None,
|
||
user_id,
|
||
storage_type,
|
||
user_rag_memory_id
|
||
) -> tuple[list, bool]:
|
||
"""加载长期记忆配置"""
|
||
if not memory_config:
|
||
return [], False
|
||
|
||
tools = []
|
||
if memory_config.get("enabled"):
|
||
if user_id:
|
||
# 创建长期记忆工具
|
||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||
user_rag_memory_id)
|
||
tools.append(memory_tool)
|
||
|
||
logger.debug(
|
||
"已添加长期记忆工具",
|
||
extra={
|
||
"user_id": user_id,
|
||
"tool_count": len(tools)
|
||
}
|
||
)
|
||
return tools, bool(memory_config.get("enabled"))
|
||
|
||
@staticmethod
|
||
def _validate_file_upload(
|
||
features_config: Dict[str, Any],
|
||
files: Optional[List[FileInput]]
|
||
) -> None:
|
||
"""校验上传文件是否符合 file_upload 配置"""
|
||
if not files or not features_config:
|
||
return
|
||
fu = features_config.get("file_upload", {})
|
||
if not (isinstance(fu, dict) and fu.get("enabled")):
|
||
raise BusinessException("该应用未开启文件上传功能", BizCode.BAD_REQUEST)
|
||
max_count = fu.get("max_file_count", 5)
|
||
if len(files) > max_count:
|
||
raise BusinessException(f"文件数量超过限制(最多 {max_count} 个)", BizCode.BAD_REQUEST)
|
||
|
||
# 校验传输方式
|
||
allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"])
|
||
for f in files:
|
||
if f.transfer_method.value not in allowed_methods:
|
||
raise BusinessException(
|
||
f"不支持的文件传输方式:{f.transfer_method.value},允许的方式:{', '.join(allowed_methods)}",
|
||
BizCode.BAD_REQUEST
|
||
)
|
||
|
||
# 各类型对应的开关和大小限制配置键
|
||
type_cfg = {
|
||
"image": ("image_enabled", "image_max_size_mb", 20, "图片"),
|
||
"audio": ("audio_enabled", "audio_max_size_mb", 50, "音频"),
|
||
"document": ("document_enabled", "document_max_size_mb", 100, "文档"),
|
||
"video": ("video_enabled", "video_max_size_mb", 500, "视频"),
|
||
}
|
||
|
||
for f in files:
|
||
ftype = str(f.type) # 如 "image", "audio", "document", "video"
|
||
cfg = type_cfg.get(ftype)
|
||
if cfg is None:
|
||
continue
|
||
enabled_key, size_key, default_max_mb, label = cfg
|
||
|
||
# 校验类型开关
|
||
if not fu.get(enabled_key):
|
||
raise BusinessException(f"该应用未开启{label}文件上传", BizCode.BAD_REQUEST)
|
||
|
||
# 校验文件大小(仅当内容已加载时)
|
||
content = f.get_content()
|
||
if content is not None:
|
||
max_mb = fu.get(size_key, default_max_mb)
|
||
size_mb = len(content) / (1024 * 1024)
|
||
if size_mb > max_mb:
|
||
raise BusinessException(
|
||
f"{label}文件大小超过限制(最大 {max_mb}MB,当前 {size_mb:.1f}MB)",
|
||
BizCode.BAD_REQUEST
|
||
)
|
||
|
||
@staticmethod
|
||
def _get_opening_statement(
|
||
features_config: Dict[str, Any],
|
||
is_new_conversation: bool,
|
||
variables: Optional[Dict[str, Any]] = None
|
||
) -> tuple[Any, Any]:
|
||
"""首轮对话时返回开场白文本(支持变量替换),否则返回 None"""
|
||
if not is_new_conversation:
|
||
return None, None
|
||
opening = features_config.get("opening_statement", {})
|
||
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
|
||
return None, None
|
||
|
||
statement = opening["statement"]
|
||
suggested_questions = opening["suggested_questions"]
|
||
|
||
# 如果有变量,进行替换(仅支持 {{var_name}} 格式)
|
||
if variables:
|
||
for var_name, var_value in variables.items():
|
||
placeholder = f"{{{{{var_name}}}}}"
|
||
statement = statement.replace(placeholder, str(var_value))
|
||
|
||
return statement, suggested_questions
|
||
|
||
@staticmethod
|
||
def _filter_citations(
|
||
features_config: Dict[str, Any],
|
||
citations: List[Citation]
|
||
) -> List[Any]:
|
||
"""根据 citation 开关决定是否返回引用来源,并根据 allow_download 附加下载链接"""
|
||
citation_cfg = features_config.get("citation", {})
|
||
if not (isinstance(citation_cfg, dict) and citation_cfg.get("enabled")):
|
||
return []
|
||
allow_download = citation_cfg.get("allow_download", False)
|
||
result = []
|
||
for cit in citations:
|
||
item = cit.model_dump() if hasattr(cit, "model_dump") else dict(cit)
|
||
if allow_download and item.get("document_id"):
|
||
from app.core.config import settings
|
||
item["download_url"] = f"{settings.FILE_LOCAL_SERVER_URL}/apps/citations/{item['document_id']}/download"
|
||
result.append(item)
|
||
return result
|
||
|
||
async def run(
    self,
    *,
    agent_config: AgentConfig,
    model_config: ModelConfig,
    message: str,
    workspace_id: uuid.UUID,
    conversation_id: Optional[str] = None,
    user_id: Optional[str] = None,
    variables: Optional[Dict[str, Any]] = None,
    storage_type: Optional[str] = None,
    user_rag_memory_id: Optional[str] = None,
    web_search: bool = True,
    memory: bool = True,
    sub_agent: bool = False,
    files: Optional[List[FileInput]] = None  # multimodal files
) -> Dict[str, Any]:
    """Run the agent once (non-streaming) through LangChain.

    Args:
        agent_config: Agent configuration.
        model_config: Model configuration.
        message: User message.
        workspace_id: Workspace id (required, used for conversation isolation).
        conversation_id: Conversation id for multi-turn chats; a new one is ensured if empty.
        user_id: User id.
        variables: Custom variable values used to render the system prompt.
        storage_type: Memory storage type (optional).
        user_rag_memory_id: User RAG memory id (optional).
        web_search: Enable web search (default True); forced off unless the
            ``web_search`` feature is enabled.
        memory: Enable long-term memory (default True).
        sub_agent: Whether this run is a sub-agent call (default False).
        files: Multimodal files (optional).

    Returns:
        Dict with the reply (``message``), ``reasoning_content``, ``conversation_id``,
        ``usage``, ``elapsed_time``, ``suggested_questions``, ``citations``,
        ``audio_url`` and ``audio_status``.

    Raises:
        BusinessException: From file validation or helpers (re-raised unchanged);
            any other failure is wrapped with ``BizCode.INTERNAL_ERROR``.
    """
    start_time = time.time()
    tools_config: dict | list | None = agent_config.tools
    skills_config: dict | None = agent_config.skills
    knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
    memory_config: dict | None = agent_config.memory
    features_config: dict = agent_config.features or {}

    # Feature switches take precedence over the parameter defaults.
    web_search_feature = features_config.get("web_search", {})
    if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
        web_search = False

    # file_upload validation
    self._validate_file_upload(features_config, files)

    try:
        # 1. API key configuration
        api_key_config = await self._get_api_key(model_config.id)
        logger.debug(
            "API Key 配置获取成功",
            extra={
                "model_name": api_key_config["model_name"],
                "has_api_key": bool(api_key_config["api_key"]),
                "has_api_base": bool(api_key_config.get("api_base"))
            }
        )

        # 2. Merge model parameters
        effective_params = ModelParameterMerger.get_effective_parameters(
            model_config=model_config,
            agent_config=agent_config
        )

        # Required variables are validated for top-level runs, exactly as in run_stream;
        # the previous condition was inverted and only validated sub-agent calls.
        if not sub_agent:
            variables = self.prepare_variables(variables, agent_config.variables)
        else:
            # FIXME: subagent input valid
            variables = variables or {}

        # 3. System prompt (variable substitution)
        rendered_prompt = render_prompt_message(
            agent_config.system_prompt,
            PromptMessageRole.USER,
            variables
        )
        system_prompt = rendered_prompt.get_text_content() or "你是一个专业的AI助手"

        # 4. Tools
        tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
        tools = []
        tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
        skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
        tools.extend(skill_tools)
        if skill_prompts:
            system_prompt = f"{system_prompt}\n\n{skill_prompts}"
        kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
        tools.extend(kb_tools)
        if memory:
            memory_tools, _memory_enabled = self.load_memory_config(
                memory_config, user_id, storage_type, user_rag_memory_id
            )
            tools.extend(memory_tools)

        # 5. Ensure the conversation; new conversations get the opening statement
        is_new_conversation = not conversation_id
        opening, suggested_questions = None, None
        if not sub_agent:
            opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
        conversation_id = await self._ensure_conversation(
            conversation_id=conversation_id,
            app_id=agent_config.app_id,
            workspace_id=workspace_id,
            user_id=user_id,
            sub_agent=sub_agent,  # same call shape as run_stream
            opening_statement=opening,
            suggested_questions=suggested_questions
        )

        model_info = ModelInfo(
            model_name=api_key_config["model_name"],
            provider=api_key_config["provider"],
            api_key=api_key_config["api_key"],
            api_base=api_key_config["api_base"],
            capability=api_key_config["capability"],
            is_omni=api_key_config["is_omni"],
            model_type=model_config.type
        )

        # 6. History (includes the opening statement). max_history follows the memory
        #    config like run_stream, but tolerates agents without a memory config.
        history = await self._load_conversation_history(
            conversation_id=conversation_id,
            max_history=(memory_config or {}).get("max_history", 10),
            current_provider=api_key_config.get("provider"),
            current_is_omni=api_key_config.get("is_omni", False)
        )

        # 7. Multimodal files
        processed_files = None
        if files:
            provider = api_key_config.get("provider", "openai")
            multimodal_service = MultimodalService(self.db, model_info)
            fu_config = features_config.get("file_upload", {})
            if hasattr(fu_config, "model_dump"):
                fu_config = fu_config.model_dump()
            doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False)
            processed_files = await multimodal_service.process_files(
                files, document_image_recognition=doc_img_recognition,
                workspace_id=workspace_id
            )
            logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
            capability = api_key_config.get("capability", [])
            has_doc_with_images = (
                doc_img_recognition
                and "vision" in capability
                and any(f.type == FileType.DOCUMENT for f in files)
            )
            if has_doc_with_images:
                # The Markdown image syntax was missing from this instruction, leaving
                # the model without the format it is told to copy the URL into.
                system_prompt += (
                    "\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
                    "请在回答中用 Markdown 格式 ![](url) 展示对应图片。"
                    "重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
                    "必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
                )

        agent = LangChainAgent(
            model_name=api_key_config["model_name"],
            api_key=api_key_config["api_key"],
            provider=api_key_config.get("provider", "openai"),
            api_base=api_key_config.get("api_base"),
            is_omni=api_key_config.get("is_omni", False),
            temperature=effective_params.get("temperature", 0.7),
            max_tokens=effective_params.get("max_tokens", 2000),
            system_prompt=system_prompt,
            tools=tools,
            deep_thinking=effective_params.get("deep_thinking", False),
            thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
            json_output=effective_params.get("json_output", False),
            capability=api_key_config.get("capability", []),
        )

        # Inject runtime context into tools that need it
        for t in tools:
            if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
                t.tool_instance.set_runtime_context(
                    user_id=user_id or "anonymous",
                    conversation_id=str(conversation_id) if conversation_id else None,
                    uploaded_files=processed_files or []
                )

        # Knowledge retrieval happens through the retrieval tool; no pre-fetched context.
        context = None

        logger.debug(
            "准备调用 LangChain Agent",
            extra={
                "model": api_key_config["model_name"],
                "has_history": bool(history),
                "has_context": bool(context),
                "has_files": bool(processed_files)
            }
        )

        # 8. Call the agent (multimodal aware)
        result = await agent.chat(
            message=message,
            history=history,
            context=context,
            files=processed_files
        )

        elapsed_time = time.time() - start_time
        ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))

        # 9. TTS audio_url, generated before saving so it is stored in the message meta data
        audio_url = None
        if not sub_agent:
            audio_url = await self._generate_tts(
                features_config, result["content"], api_key_config,
                tenant_id=tenant_id, workspace_id=workspace_id
            )

        # Filter citations once and reuse for storage and response
        filtered_citations = self._filter_citations(features_config, citations_collector)
        usage = result.get("usage", {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        })

        # 10. Persist the conversation turn
        if not sub_agent:
            await self._save_conversation_message(
                conversation_id=conversation_id,
                user_message=message,
                assistant_message=result["content"],
                app_id=agent_config.app_id,
                user_id=user_id,
                meta_data={
                    "usage": usage,
                    "reasoning_content": result.get("reasoning_content")
                },
                files=files,
                processed_files=processed_files,
                audio_url=audio_url,
                citations=filtered_citations,
                provider=api_key_config.get("provider"),
                is_omni=api_key_config.get("is_omni", False)
            )

        suggested = [] if sub_agent else await self._generate_suggested_questions(
            features_config, result["content"], api_key_config, effective_params
        )

        response = {
            "message": result["content"],
            "reasoning_content": result.get("reasoning_content"),
            "conversation_id": conversation_id,
            "usage": usage,
            "elapsed_time": elapsed_time,
            "suggested_questions": suggested,
            "citations": filtered_citations,
            "audio_url": audio_url,
            "audio_status": "pending" if audio_url else None
        }

        logger.info(
            "试运行完成",
            extra={
                "model": model_config.name,
                "elapsed_time": elapsed_time,
                "message_length": len(result["content"]),
                "total_tokens": (usage or {}).get("total_tokens", 0)
            }
        )

        return response

    except BusinessException:
        # Keep the original business code instead of collapsing it into INTERNAL_ERROR.
        raise
    except Exception as e:
        logger.error("LangChain Agent 调用失败", extra={"error": str(e), "error_type": type(e).__name__})
        raise BusinessException(f"Agent 调用失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
|
||
|
||
async def run_stream(
|
||
self,
|
||
*,
|
||
agent_config: AgentConfig,
|
||
model_config: ModelConfig,
|
||
message: str,
|
||
workspace_id: uuid.UUID,
|
||
conversation_id: Optional[str] = None,
|
||
user_id: Optional[str] = None,
|
||
variables: Optional[Dict[str, Any]] = None,
|
||
storage_type: Optional[str] = None,
|
||
user_rag_memory_id: Optional[str] = None,
|
||
web_search: bool = True, # 布尔类型默认值
|
||
memory: bool = True, # 布尔类型默认值
|
||
sub_agent: bool = False, # 是否是作为子Agent运行
|
||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||
|
||
) -> AsyncGenerator[str, None]:
|
||
"""执行试运行(流式返回,使用 LangChain Agent)
|
||
|
||
Args:
|
||
agent_config: Agent 配置
|
||
model_config: 模型配置
|
||
message: 用户消息
|
||
workspace_id: 工作空间ID(必须,用于会话隔离)
|
||
conversation_id: 会话ID(用于多轮对话)
|
||
user_id: 用户ID
|
||
variables: 自定义变量参数值
|
||
|
||
Yields:
|
||
str: SSE 格式的事件数据
|
||
"""
|
||
tools_config: dict | list | None = agent_config.tools
|
||
skills_config: dict | None = agent_config.skills
|
||
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||
memory_config: dict | None = agent_config.memory
|
||
features_config: dict = agent_config.features or {}
|
||
|
||
# 从 features 中读取功能开关
|
||
web_search_feature = features_config.get("web_search", {})
|
||
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
|
||
web_search = False
|
||
|
||
# file_upload 校验
|
||
self._validate_file_upload(features_config, files)
|
||
|
||
start_time = time.time()
|
||
|
||
try:
|
||
# 1. 获取 API Key 配置
|
||
api_key_config = await self._get_api_key(model_config.id)
|
||
if not sub_agent:
|
||
variables = self.prepare_variables(variables, agent_config.variables)
|
||
else:
|
||
# FIXME: subagent input valid
|
||
variables = variables or {}
|
||
|
||
# 2. 合并模型参数
|
||
effective_params = ModelParameterMerger.get_effective_parameters(
|
||
model_config=model_config,
|
||
agent_config=agent_config
|
||
)
|
||
|
||
items_params = variables
|
||
|
||
system_prompt = render_prompt_message(
|
||
agent_config.system_prompt, # 修正拼写错误
|
||
PromptMessageRole.USER,
|
||
items_params
|
||
)
|
||
|
||
# 3. 处理系统提示词(支持变量替换)
|
||
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
||
|
||
# 4. 准备工具列表
|
||
tools = []
|
||
|
||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||
|
||
# 从配置中获取启用的工具
|
||
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
|
||
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
|
||
tools.extend(skill_tools)
|
||
if skill_prompts:
|
||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||
kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
|
||
tools.extend(kb_tools)
|
||
|
||
# 添加长期记忆工具
|
||
memory_flag = False
|
||
if memory:
|
||
memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type,
|
||
user_rag_memory_id)
|
||
tools.extend(memory_tools)
|
||
|
||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||
is_new_conversation = not conversation_id
|
||
opening, suggested_questions = None, None
|
||
if not sub_agent:
|
||
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||
conversation_id = await self._ensure_conversation(
|
||
conversation_id=conversation_id,
|
||
app_id=agent_config.app_id,
|
||
workspace_id=workspace_id,
|
||
user_id=user_id,
|
||
sub_agent=sub_agent,
|
||
opening_statement=opening,
|
||
suggested_questions=suggested_questions
|
||
)
|
||
|
||
model_info = ModelInfo(
|
||
model_name=api_key_config["model_name"],
|
||
provider=api_key_config["provider"],
|
||
api_key=api_key_config["api_key"],
|
||
api_base=api_key_config["api_base"],
|
||
capability=api_key_config["capability"],
|
||
is_omni=api_key_config["is_omni"],
|
||
model_type=model_config.type
|
||
)
|
||
|
||
# 6. 加载历史消息
|
||
history = await self._load_conversation_history(
|
||
conversation_id=conversation_id,
|
||
max_history=memory_config.get("max_history", 10),
|
||
current_provider=api_key_config.get("provider"),
|
||
current_is_omni=api_key_config.get("is_omni", False)
|
||
)
|
||
|
||
# 6. 处理多模态文件
|
||
processed_files = None
|
||
has_doc_with_images = False
|
||
if files:
|
||
provider = api_key_config.get("provider", "openai")
|
||
multimodal_service = MultimodalService(self.db, model_info)
|
||
fu_config = features_config.get("file_upload", {})
|
||
if hasattr(fu_config, "model_dump"):
|
||
fu_config = fu_config.model_dump()
|
||
doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False)
|
||
processed_files = await multimodal_service.process_files(
|
||
files, document_image_recognition=doc_img_recognition,
|
||
workspace_id=workspace_id
|
||
)
|
||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||
capability = api_key_config.get("capability", [])
|
||
has_doc_with_images = (
|
||
doc_img_recognition
|
||
and "vision" in capability
|
||
and any(f.type == FileType.DOCUMENT for f in files)
|
||
)
|
||
if has_doc_with_images:
|
||
system_prompt += (
|
||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
||
"请在回答中用 Markdown 格式  展示对应图片。"
|
||
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
||
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
||
)
|
||
|
||
# 创建 LangChain Agent
|
||
agent = LangChainAgent(
|
||
model_name=api_key_config["model_name"],
|
||
api_key=api_key_config["api_key"],
|
||
provider=api_key_config.get("provider", "openai"),
|
||
api_base=api_key_config.get("api_base"),
|
||
is_omni=api_key_config.get("is_omni", False),
|
||
temperature=effective_params.get("temperature", 0.7),
|
||
max_tokens=effective_params.get("max_tokens", 2000),
|
||
system_prompt=system_prompt,
|
||
tools=tools,
|
||
streaming=True,
|
||
deep_thinking=effective_params.get("deep_thinking", False),
|
||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||
json_output=effective_params.get("json_output", False),
|
||
capability=api_key_config.get("capability", []),
|
||
)
|
||
|
||
# 为需要运行时上下文的工具注入上下文
|
||
for t in tools:
|
||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||
t.tool_instance.set_runtime_context(
|
||
user_id=user_id or "anonymous",
|
||
conversation_id=str(conversation_id) if conversation_id else None,
|
||
uploaded_files=processed_files or []
|
||
)
|
||
# 7. 知识库检索
|
||
context = None
|
||
|
||
# 8. 发送开始事件
|
||
yield self._format_sse_event("start", {
|
||
"conversation_id": conversation_id,
|
||
"timestamp": time.time()
|
||
})
|
||
|
||
memory_config_ = agent_config.memory
|
||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
|
||
|
||
# 9. 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||
full_content = ""
|
||
full_reasoning = ""
|
||
total_tokens = 0
|
||
|
||
# 启动流式 TTS(文本边输出边合成)
|
||
text_queue: asyncio.Queue = asyncio.Queue()
|
||
stream_audio_url, tts_task = await self._generate_tts_streaming(
|
||
features_config, api_key_config,
|
||
text_queue=text_queue,
|
||
tenant_id=tenant_id, workspace_id=workspace_id
|
||
) if not sub_agent else (None, None)
|
||
|
||
async for chunk in agent.chat_stream(
|
||
message=message,
|
||
history=history,
|
||
context=context,
|
||
files=processed_files
|
||
):
|
||
if isinstance(chunk, int):
|
||
total_tokens = chunk
|
||
elif isinstance(chunk, dict) and chunk.get("type") == "reasoning":
|
||
full_reasoning += chunk["content"]
|
||
yield self._format_sse_event("reasoning", {"content": chunk["content"]})
|
||
else:
|
||
full_content += chunk
|
||
yield self._format_sse_event("message", {"content": chunk})
|
||
if tts_task is not None:
|
||
await text_queue.put(chunk)
|
||
|
||
# 文本结束,通知 TTS
|
||
if tts_task is not None:
|
||
await text_queue.put(None)
|
||
|
||
elapsed_time = time.time() - start_time
|
||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||
|
||
if sub_agent:
|
||
yield self._format_sse_event("sub_usage", {"total_tokens": total_tokens})
|
||
|
||
# 过滤 citations(只调用一次)
|
||
filtered_citations = self._filter_citations(features_config, citations_collector)
|
||
|
||
# 11. 保存会话消息
|
||
if not sub_agent:
|
||
await self._save_conversation_message(
|
||
conversation_id=conversation_id,
|
||
user_message=message,
|
||
assistant_message=full_content,
|
||
app_id=agent_config.app_id,
|
||
user_id=user_id,
|
||
meta_data={
|
||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens},
|
||
"reasoning_content": full_reasoning or None
|
||
},
|
||
files=files,
|
||
processed_files=processed_files,
|
||
audio_url=stream_audio_url,
|
||
citations=filtered_citations,
|
||
provider=api_key_config.get("provider"),
|
||
is_omni=api_key_config.get("is_omni", False)
|
||
)
|
||
|
||
# 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status)
|
||
end_data: Dict[str, Any] = {
|
||
"conversation_id": conversation_id,
|
||
"elapsed_time": elapsed_time,
|
||
"message_length": len(full_content)
|
||
}
|
||
if not sub_agent:
|
||
end_data["suggested_questions"] = await self._generate_suggested_questions(
|
||
features_config, full_content, api_key_config, effective_params
|
||
)
|
||
end_data["audio_url"] = stream_audio_url
|
||
# 检查TTS是否已完成(非阻塞,不取消任务)
|
||
audio_status = "pending"
|
||
if tts_task is not None and tts_task.done():
|
||
# 任务已完成,检查是否有异常
|
||
try:
|
||
tts_task.result()
|
||
audio_status = "completed"
|
||
except Exception as e:
|
||
logger.warning(f"TTS任务异常: {e}")
|
||
audio_status = "failed"
|
||
end_data["audio_status"] = audio_status if stream_audio_url else None
|
||
end_data["citations"] = filtered_citations
|
||
yield self._format_sse_event("end", end_data)
|
||
|
||
logger.info(
|
||
"流式试运行完成",
|
||
extra={
|
||
"model": model_config.name,
|
||
"elapsed_time": elapsed_time,
|
||
"message_length": len(full_content)
|
||
}
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error("流式 Agent 调用失败", extra={"error": str(e)}, exc_info=True)
|
||
# 发送错误事件
|
||
yield self._format_sse_event("error", {
|
||
"error": str(e),
|
||
"timestamp": time.time()
|
||
})
|
||
|
||
def _format_sse_event(self, event: str, data: Dict[str, Any]) -> str:
    """Render a single Server-Sent Events frame.

    Args:
        event: name written on the ``event:`` line.
        data: payload written as JSON on the ``data:`` line; non-ASCII
            characters are kept verbatim (``ensure_ascii=False``).

    Returns:
        str: the frame text, terminated by an empty line.
    """
    payload = json.dumps(data, ensure_ascii=False)
    event_line = f"event: {event}\n"
    data_line = f"data: {payload}\n"
    return event_line + data_line + "\n"
|
||
|
||
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict:
    """Resolve an available API key for a model configuration.

    Key selection is delegated to ``ModelApiKeyService.get_available_api_key``.

    Args:
        model_config_id: ID of the model configuration.

    Returns:
        Dict: connection settings copied from the selected key record, with
        keys ``model_name``, ``provider``, ``api_key``, ``api_base``,
        ``api_key_id``, ``is_omni`` and ``capability``.

    Raises:
        BusinessException: (``BizCode.AGENT_CONFIG_MISSING``) when no API key
        is available for the model configuration.
    """
    api_key = ModelApiKeyService.get_available_api_key(self.db, model_config_id)

    if not api_key:
        raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)

    return {
        "model_name": api_key.model_name,
        "provider": api_key.provider,
        "api_key": api_key.api_key,
        "api_base": api_key.api_base,
        "api_key_id": api_key.id,
        "is_omni": api_key.is_omni,
        "capability": api_key.capability
    }
|
||
|
||
async def _ensure_conversation(
    self,
    conversation_id: Optional[str],
    app_id: uuid.UUID,
    workspace_id: uuid.UUID,
    user_id: Optional[str],
    sub_agent: bool = False,
    opening_statement: Optional[str] = None,
    suggested_questions: Optional[List[str]] = None
) -> str:
    """Return the ID of a usable conversation, creating a draft one if needed.

    Without ``conversation_id`` a new draft conversation is created (with a
    config snapshot of the app) and, if given, the opening statement is stored
    as its first assistant message. With ``conversation_id`` the conversation is
    loaded and, unless ``sub_agent`` is set, checked for access.

    Args:
        conversation_id: existing conversation ID (UUID string), or None/empty
            to create a new draft conversation.
        app_id: application ID.
        workspace_id: ID of the calling workspace (required).
        user_id: user ID stored on a newly created conversation.
        sub_agent: internal sub-agent call; the access check is skipped because
            the outer call already performed it.
        opening_statement: opening message written for a new conversation.
        suggested_questions: preset questions stored with the opening message.

    Returns:
        str: the conversation ID.

    Raises:
        BusinessException: ``PERMISSION_DENIED`` when the conversation belongs to
            another workspace and to a different app; ``NOT_FOUND`` when the ID
            is invalid or the conversation cannot be loaded.
    """
    from app.models import Conversation as ConversationModel

    conversation_service = ConversationService(self.db)

    # No conversation ID: create a new draft conversation.
    if not conversation_id:
        logger.info(
            "创建新的草稿会话",
            extra={"workspace_id": str(workspace_id)}
        )

        config_snapshot = await self._get_config_snapshot(app_id)

        new_conv_id = str(uuid.uuid4())
        new_conversation = ConversationModel(
            id=uuid.UUID(new_conv_id),
            app_id=app_id,
            workspace_id=workspace_id,
            user_id=user_id,
            is_draft=True,
            title="草稿会话",
            config_snapshot=config_snapshot
        )
        self.db.add(new_conversation)
        self.db.commit()
        self.db.refresh(new_conversation)

        # Persist the opening statement as the first assistant message.
        if opening_statement:
            conversation_service.add_message(
                conversation_id=uuid.UUID(new_conv_id),
                role="assistant",
                content=opening_statement,
                meta_data={"suggested_questions": suggested_questions}
            )
            logger.debug(f"已保存开场白到会话 {new_conv_id}")

        logger.info(
            "创建草稿会话成功",
            extra={
                "conversation_id": new_conv_id,
                "workspace_id": str(workspace_id)
            }
        )
        return new_conv_id

    # Existing conversation: validate that it exists and may be used here.
    try:
        conv_uuid = uuid.UUID(conversation_id)
        conversation = conversation_service.get_conversation(conv_uuid)

        if not sub_agent and conversation.workspace_id != workspace_id:
            # A conversation from another workspace (e.g. one created by a
            # workspace the app is shared with) is accepted only when it belongs
            # to the app being run. Previously an AppShare row for app_id alone
            # was sufficient, which let a caller attach to a conversation of an
            # unrelated app.
            if conversation.app_id != app_id:
                logger.warning(
                    "会话不属于当前工作空间",
                    extra={
                        "conversation_id": conversation_id,
                        "conversation_workspace_id": str(conversation.workspace_id),
                        "conversation_app_id": str(conversation.app_id),
                        "current_workspace_id": str(workspace_id)
                    }
                )
                raise BusinessException(
                    "会话不属于当前工作空间",
                    BizCode.PERMISSION_DENIED
                )

        logger.debug(
            "使用现有会话",
            extra={
                "conversation_id": conversation_id,
                "workspace_id": str(workspace_id)
            }
        )
        return conversation_id

    except BusinessException:
        raise
    except Exception as e:
        # Invalid UUID strings and lookup failures are both reported as "not found".
        logger.error(
            "会话不存在或无效",
            extra={"conversation_id": conversation_id, "error": str(e)}
        )
        raise BusinessException(
            f"会话不存在: {conversation_id}",
            BizCode.NOT_FOUND,
            cause=e
        )
|
||
|
||
async def _load_conversation_history(
    self,
    conversation_id: str,
    max_history: int = 10,
    current_provider: Optional[str] = None,
    current_is_omni: Optional[bool] = None
) -> List[Dict[str, str]]:
    """Load the most recent messages of a conversation as chat history.

    Attachments stored with past messages are prepared for the current model by
    ``ConversationService.get_conversation_history``, which is why the current
    provider and omni flag are passed through.

    Args:
        conversation_id: conversation ID (UUID string).
        max_history: maximum number of history messages to load.
        current_provider: provider of the model that will consume the history.
        current_is_omni: whether that model is omni-modal.

    Returns:
        List[Dict]: history messages, or an empty list if loading fails for any
        reason.
    """
    try:
        service = ConversationService(self.db)
        query = dict(
            conversation_id=uuid.UUID(conversation_id),
            max_history=max_history,
            current_provider=current_provider,
            current_is_omni=current_is_omni,
        )
        messages = await service.get_conversation_history(**query)

        logger.debug(
            "加载会话历史",
            extra=dict(
                conversation_id=conversation_id,
                max_history=max_history,
                loaded_count=len(messages),
            ),
        )
        return messages
    except Exception as e:
        # A brand-new conversation without history is expected, so failures are
        # only logged at debug level.
        logger.debug("加载会话历史失败(可能是新会话)", extra={"error": str(e)})
        return []
|
||
|
||
async def _save_conversation_message(
    self,
    conversation_id: str,
    user_message: str,
    assistant_message: str,
    meta_data: dict,
    app_id: Optional[uuid.UUID] = None,
    user_id: Optional[str] = None,
    files: Optional[List[FileInput]] = None,
    processed_files: Optional[List[Dict[str, Any]]] = None,
    audio_url: Optional[str] = None,
    citations: Optional[List[Any]] = None,
    provider: Optional[str] = None,
    is_omni: Optional[bool] = None
) -> None:
    """Persist one user/assistant exchange to an existing conversation.

    The conversation must already exist (see ``_ensure_conversation``). This is
    best-effort: any error is logged and swallowed so that a failed write does
    not break the chat response.

    Args:
        conversation_id: conversation ID (UUID string).
        user_message: user message text.
        assistant_message: assistant reply text.
        meta_data: metadata for the assistant message (e.g. token usage). It is
            copied, not modified.
        app_id: application ID (unused, kept for compatibility).
        user_id: user ID (unused, kept for compatibility).
        files: raw file inputs attached to the user message.
        processed_files: processed file content, stored as ``history_files``.
        audio_url: TTS audio URL, added to the assistant metadata if set.
        citations: citation sources, added to the assistant metadata if non-empty.
        provider: model provider, stored with ``history_files``.
        is_omni: whether the model is omni-modal, stored with ``history_files``.
    """
    try:
        conversation_service = ConversationService(self.db)
        conv_uuid = uuid.UUID(conversation_id)

        def _needs_lookup(f: FileInput) -> bool:
            # Local uploads may arrive without name/size; those are filled in
            # from the FileMetadata table.
            return (
                f.transfer_method.value == "local_file"
                and bool(f.upload_file_id)
                and (not f.name or not f.size)
            )

        human_meta: Dict[str, Any] = {
            "files": [],
            "history_files": {}
        }

        if files:
            from app.models.file_metadata_model import FileMetadata

            local_ids = [f.upload_file_id for f in files if _needs_lookup(f)]
            meta_map = {}
            if local_ids:
                rows = self.db.query(FileMetadata).filter(
                    FileMetadata.id.in_(local_ids),
                    FileMetadata.status == "completed"
                ).all()
                meta_map = {str(r.id): r for r in rows}

            for f in files:
                name, size = f.name, f.size
                if _needs_lookup(f):
                    meta = meta_map.get(str(f.upload_file_id))
                    if meta:
                        name = name or meta.file_name
                        size = size or meta.file_size
                human_meta["files"].append({
                    "type": f.type,
                    "url": f.url,
                    "file_type": f.file_type,
                    "name": name,
                    "size": size
                })

        # history_files records which provider / omni mode produced the content.
        if processed_files:
            human_meta["history_files"] = {
                "content": processed_files,
                "provider": provider,
                "is_omni": is_omni
            }

        conversation_service.add_message(
            conversation_id=conv_uuid,
            role="user",
            content=user_message,
            meta_data=human_meta
        )

        # Copy so the caller's dict is not mutated with audio_url / citations.
        assistant_meta = dict(meta_data or {})
        if audio_url:
            assistant_meta["audio_url"] = audio_url
        if citations:
            assistant_meta["citations"] = citations
        conversation_service.add_message(
            conversation_id=conv_uuid,
            role="assistant",
            content=assistant_message,
            meta_data=assistant_meta
        )

        logger.debug(
            "保存会话消息",
            extra={
                "conversation_id": conversation_id,
                "user_message_length": len(user_message),
                "assistant_message_length": len(assistant_message)
            }
        )

    except Exception as e:
        # Best-effort persistence: log with traceback, do not propagate.
        logger.warning("保存会话消息失败", extra={"error": str(e)}, exc_info=True)
|
||
|
||
async def _get_config_snapshot(self, app_id: uuid.UUID) -> Dict[str, Any]:
    """Build a JSON-friendly snapshot of an app's current agent/model configuration.

    Args:
        app_id: application ID.

    Returns:
        Dict: ``agent_config`` fields, a ``model_config`` summary (or None) and
        ``snapshot_time``. An empty dict is returned when the app has no
        AgentConfig or when anything fails (expected for multi-agent apps).
    """
    try:
        stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
        agent_cfg = self.db.scalars(stmt).first()

        if not agent_cfg:
            return {}

        model_config = None
        if agent_cfg.default_model_config_id:
            model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)

        def safe_serialize(value):
            """Convert a config value into a JSON-serializable form."""
            if value is None or isinstance(value, (str, int, float, bool, dict, list)):
                return value
            # Pydantic v2: mode="json" also converts nested UUID/datetime/enum values,
            # which plain .dict() (deprecated in v2) leaves as Python objects.
            if hasattr(value, "model_dump"):
                return value.model_dump(mode="json")
            # Pydantic v1 and other dict-capable objects.
            if hasattr(value, "dict"):
                return value.dict()
            if hasattr(value, "__dict__"):
                # Skip private/internal attributes (e.g. SQLAlchemy's
                # _sa_instance_state), which are not serializable.
                return {k: v for k, v in vars(value).items() if not k.startswith("_")}
            return str(value)

        model_summary = None
        if model_config:
            model_summary = {
                "model_name": model_config.name,
                "provider": model_config.provider,
                "type": model_config.type
            }

        return {
            "agent_config": {
                "system_prompt": agent_cfg.system_prompt,
                "model_parameters": safe_serialize(agent_cfg.model_parameters),
                "knowledge_retrieval": safe_serialize(agent_cfg.knowledge_retrieval),
                "memory": safe_serialize(agent_cfg.memory),
                "variables": safe_serialize(agent_cfg.variables),
                "tools": safe_serialize(agent_cfg.tools)
            },
            "model_config": model_summary,
            "snapshot_time": datetime.datetime.now().isoformat()
        }

    except Exception as e:
        # Multi-agent apps have no direct AgentConfig, so failures are expected here.
        logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
        return {}
|
||
|
||
async def _generate_suggested_questions(
    self,
    features_config: Dict[str, Any],
    assistant_message: str,
    api_key_config: Dict[str, Any],
    effective_params: Dict[str, Any]
) -> List[str]:
    """Generate up to three follow-up questions for an assistant reply.

    Only active when ``features_config["suggested_questions_after_answer"]`` is a
    dict with a truthy ``enabled``. A separate ``ChatOpenAI`` call (temperature
    0.5, max 200 tokens) is made with the given API key settings.

    Args:
        features_config: app feature flags.
        assistant_message: the assistant reply to base the questions on.
        api_key_config: ``model_name``, ``api_key`` and optional ``api_base``.
        effective_params: effective model parameters (currently unused).

    Returns:
        List[str]: at most three non-empty questions; an empty list when the
        feature is disabled, the reply is blank, or generation fails.
    """
    sq_config = features_config.get("suggested_questions_after_answer", {})
    if not isinstance(sq_config, dict) or not sq_config.get("enabled"):
        return []
    # Nothing to base suggestions on; skip the extra LLM round-trip.
    if not assistant_message or not assistant_message.strip():
        return []
    try:
        from langchain_openai import ChatOpenAI
        from langchain_core.messages import HumanMessage

        llm = ChatOpenAI(
            model=api_key_config["model_name"],
            api_key=api_key_config["api_key"],
            base_url=api_key_config.get("api_base"),
            temperature=0.5,
            max_tokens=200,
        )
        prompt = (
            f"根据以下AI回复,生成3个用户可能继续追问的简短问题,每行一个,不加序号:\n\n{assistant_message}"
        )
        resp = await llm.ainvoke([HumanMessage(content=prompt)])

        content = resp.content
        if not isinstance(content, str):
            # Some providers return a list of content parts instead of a string.
            content = "".join(
                part.get("text", "") if isinstance(part, dict) else str(part)
                for part in (content or [])
            )
        questions = [line.strip() for line in content.splitlines() if line.strip()]
        return questions[:3]
    except Exception as e:
        logger.warning(f"生成建议问题失败: {e}")
        return []
|
||
|
||
async def _generate_tts(
    self,
    features_config: Dict[str, Any],
    text: str,
    api_key_config: Dict[str, Any],
    tenant_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None,
) -> Optional[str]:
    """Register a TTS audio file and synthesize it in the background.

    A FileMetadata row with status ``"pending"`` is committed first, so the
    returned URL can be handed out immediately. A background task then streams
    the synthesized MP3 into storage and marks the row ``"completed"`` (with the
    final size) or ``"failed"``.

    Args:
        features_config: app feature flags (``text_to_speech`` section).
        text: text to synthesize.
        api_key_config: provider, API key and optional API base.
        tenant_id: tenant owning the file.
        workspace_id: workspace owning the file.

    Returns:
        Optional[str]: the audio URL, or None when TTS is disabled or text is blank.
    """
    tts_config = features_config.get("text_to_speech", {})
    if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
        return None
    if not text or not text.strip():
        return None

    from app.models.file_metadata_model import FileMetadata
    from app.services.file_storage_service import FileStorageService, generate_file_key

    provider = api_key_config.get("provider", "openai")
    api_key = api_key_config.get("api_key")
    api_base = api_key_config.get("api_base")
    voice = tts_config.get("voice")
    file_ext, content_type = ".mp3", "audio/mpeg"

    file_id = uuid.uuid4()
    file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)

    # Commit pending metadata first so the URL is valid right away.
    db_file = FileMetadata(
        id=file_id,
        tenant_id=tenant_id,
        workspace_id=workspace_id,
        file_key=file_key,
        file_name=f"tts_{file_id}{file_ext}",
        file_ext=file_ext,
        file_size=0,
        content_type=content_type,
        status="pending",
    )
    self.db.add(db_file)
    self.db.commit()

    server_url = settings.FILE_LOCAL_SERVER_URL
    audio_url = f"{server_url}/storage/permanent/{file_id}"

    async def _stream_to_storage():
        """Synthesize, upload, and record the final status (uses its own DB session)."""
        try:
            storage_service = FileStorageService()
            if provider == "dashscope":
                stream = self._tts_dashscope_stream(
                    api_key=api_key,
                    text=text,
                    voice=voice or "longxiaochun",
                    tts_config=tts_config,
                )
            else:
                stream = self._tts_openai_stream(
                    api_key=api_key,
                    api_base=api_base,
                    text=text,
                    voice=voice or "alloy",
                )

            total_size = await storage_service.upload_stream(
                tenant_id=tenant_id,
                workspace_id=workspace_id,
                file_id=file_id,
                file_ext=file_ext,
                stream=stream,
                content_type=content_type,
            )

            with get_db_context() as bg_db:
                record = bg_db.get(FileMetadata, file_id)
                if record:
                    record.status = "completed"
                    record.file_size = total_size
                    bg_db.commit()
            logger.debug(f"TTS 流式写入完成,provider={provider}, file_key={file_key}")
        except Exception as e:
            logger.warning(f"TTS 流式写入失败: {e}", exc_info=True)
            try:
                with get_db_context() as bg_db:
                    record = bg_db.get(FileMetadata, file_id)
                    if record:
                        record.status = "failed"
                        bg_db.commit()
            except Exception:
                # Nobody awaits this task; log instead of leaving an unretrieved error.
                logger.warning("TTS 文件状态更新失败", exc_info=True)

    task = asyncio.create_task(_stream_to_storage())
    # The event loop keeps only weak references to tasks, so a fire-and-forget
    # task can be garbage-collected before it finishes. Hold it in a class-level
    # set (outliving this service instance) until it completes.
    owner = type(self)
    pending = getattr(owner, "_tts_background_tasks", None)
    if pending is None:
        pending = set()
        setattr(owner, "_tts_background_tasks", pending)
    pending.add(task)
    task.add_done_callback(pending.discard)
    return audio_url
|
||
|
||
async def _generate_tts_streaming(
    self,
    features_config: Dict[str, Any],
    api_key_config: Dict[str, Any],
    text_queue: asyncio.Queue,
    tenant_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None,
) -> tuple[Optional[str], Optional[asyncio.Task]]:
    """Synthesize speech in parallel with a streamed text response.

    The caller puts text chunks on ``text_queue`` and a final ``None`` sentinel.
    A FileMetadata row with status ``"pending"`` is committed before synthesis
    starts, so the returned URL is usable immediately. The background task later
    marks it ``"completed"`` (with the final size) or ``"failed"``. Clients can
    poll ``GET /storage/files/{file_id}/status`` to see when the audio is ready.

    Args:
        features_config: app feature flags (``text_to_speech`` section).
        api_key_config: provider, API key and optional API base.
        text_queue: queue of text chunks terminated by ``None``.
        tenant_id: tenant owning the file.
        workspace_id: workspace owning the file.

    Returns:
        ``(audio_url, task)``, or ``(None, None)`` when text_to_speech is
        disabled. When synthesis fails, the task marks the file ``"failed"`` and
        then finishes with the original exception, so ``task.result()`` reflects
        the real outcome.
    """
    tts_config = features_config.get("text_to_speech", {})
    if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
        return None, None

    from app.models.file_metadata_model import FileMetadata
    from app.services.file_storage_service import FileStorageService, generate_file_key

    provider = api_key_config.get("provider", "openai")
    api_key = api_key_config.get("api_key")
    api_base = api_key_config.get("api_base")
    voice = tts_config.get("voice")
    file_ext, content_type = ".mp3", "audio/mpeg"

    file_id = uuid.uuid4()
    file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)

    db_file = FileMetadata(
        id=file_id,
        tenant_id=tenant_id,
        workspace_id=workspace_id,
        file_key=file_key,
        file_name=f"tts_{file_id}{file_ext}",
        file_ext=file_ext,
        file_size=0,
        content_type=content_type,
        status="pending",
    )
    self.db.add(db_file)
    self.db.commit()

    server_url = settings.FILE_LOCAL_SERVER_URL
    audio_url = f"{server_url}/storage/permanent/{file_id}"

    async def _run():
        """Synthesize from the queue, upload, and record the final status."""
        try:
            storage_service = FileStorageService()
            if provider == "dashscope":
                audio_stream = self._tts_dashscope_stream_from_queue(
                    api_key=api_key,
                    voice=voice or "longxiaochun",
                    tts_config=tts_config,
                    text_queue=text_queue,
                )
            else:
                audio_stream = self._tts_openai_stream_from_queue(
                    api_key=api_key,
                    api_base=api_base,
                    voice=voice or "alloy",
                    text_queue=text_queue,
                )
            total_size = await storage_service.upload_stream(
                tenant_id=tenant_id,
                workspace_id=workspace_id,
                file_id=file_id,
                file_ext=file_ext,
                stream=audio_stream,
                content_type=content_type,
            )
            with get_db_context() as bg_db:
                record = bg_db.get(FileMetadata, file_id)
                if record:
                    record.status = "completed"
                    record.file_size = total_size
                    bg_db.commit()
            logger.debug(f"TTS 流式合成完成,provider={provider}, file_key={file_key}")
        except Exception as e:
            logger.warning(f"TTS 流式合成失败: {e}", exc_info=True)
            try:
                with get_db_context() as bg_db:
                    record = bg_db.get(FileMetadata, file_id)
                    if record:
                        record.status = "failed"
                        bg_db.commit()
            except Exception:
                logger.warning("TTS 文件状态更新失败", exc_info=True)
            # Re-raise so that task.result() reports the failure; previously the
            # task always "succeeded" and callers reported audio as completed.
            raise

    task = asyncio.create_task(_run())
    # Callers only inspect the task if it has already finished; retrieving the
    # exception here prevents "Task exception was never retrieved" warnings.
    task.add_done_callback(lambda t: None if t.cancelled() else t.exception())
    return audio_url, task
|
||
|
||
@staticmethod
async def _tts_openai_stream_from_queue(
    api_key: str,
    api_base: Optional[str],
    voice: str,
    text_queue: asyncio.Queue,
):
    """OpenAI-compatible TTS fed from a text queue; yields audio byte chunks.

    The speech endpoint takes the whole input at once (no incremental input),
    so chunks are collected until the ``None`` sentinel before synthesis
    starts. The text arrives in parallel with the streamed reply, so the wait
    is short. Input is truncated to 4096 characters; nothing is yielded for
    blank text.
    """
    from openai import AsyncOpenAI

    parts: List[str] = []
    while (chunk := await text_queue.get()) is not None:
        parts.append(chunk)
    full_text = "".join(parts)
    if not full_text.strip():
        return

    # AsyncOpenAI owns an HTTP connection pool; close it when done instead of
    # leaking one client per synthesis.
    async with AsyncOpenAI(api_key=api_key, base_url=api_base) as client:
        async with client.audio.speech.with_streaming_response.create(
            model="tts-1",
            voice=voice,
            input=full_text[:4096],
        ) as response:
            async for audio in response.iter_bytes(chunk_size=4096):
                yield audio
|
||
|
||
@staticmethod
async def _tts_dashscope_stream_from_queue(
    api_key: str,
    voice: str,
    tts_config: Dict[str, Any],
    text_queue: asyncio.Queue,
):
    """DashScope TTS with incremental text input; yields audio byte chunks.

    Text chunks from ``text_queue`` (terminated by ``None``) are buffered and
    sent to the synthesizer one sentence at a time (split after CJK or ASCII
    sentence punctuation or a newline), so synthesis runs in parallel with text
    generation. Audio from the SDK callbacks is handed back to the event loop
    through ``audio_queue``.

    Raises:
        RuntimeError: when the SDK reports an error through ``on_error``.
        Exception: any error raised while feeding text to the synthesizer.
    """
    import re

    import dashscope
    from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback

    model = tts_config.get("model") or "cosyvoice-v2"
    # Align the voice name's "_v2" suffix with the model generation.
    is_v2 = model.endswith("-v2")
    if is_v2 and not voice.endswith("_v2"):
        voice = voice + "_v2"
    elif not is_v2 and voice.endswith("_v2"):
        voice = voice[:-3]

    audio_queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    class _Callback(ResultCallback):
        # Results are marshalled onto the event loop with call_soon_threadsafe
        # (the SDK presumably invokes callbacks from its own thread).
        def on_data(self, data: bytes):
            if data:
                loop.call_soon_threadsafe(audio_queue.put_nowait, data)

        def on_complete(self):
            loop.call_soon_threadsafe(audio_queue.put_nowait, None)

        def on_error(self, message):
            loop.call_soon_threadsafe(audio_queue.put_nowait, RuntimeError(str(message)))

        def on_open(self):
            pass

        def on_close(self):
            pass

    # NOTE(review): dashscope.api_key is module-global; concurrent requests with
    # different keys may race here.
    dashscope.api_key = api_key
    synthesizer = SpeechSynthesizer(
        model=model,
        voice=voice,
        format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
        callback=_Callback(),
    )

    sentence_end = re.compile(r'[\u3002\uff01\uff1f.!?\n]')

    async def _feed_text():
        """Feed complete sentences to the synthesizer; flush the rest at the sentinel."""
        buf = ""
        try:
            while True:
                chunk = await text_queue.get()
                if chunk is None:
                    if buf.strip():
                        await asyncio.to_thread(synthesizer.streaming_call, buf)
                    await asyncio.to_thread(synthesizer.streaming_complete)
                    return
                buf += chunk
                while (m := sentence_end.search(buf)):
                    sentence, buf = buf[:m.end()], buf[m.end():]
                    await asyncio.to_thread(synthesizer.streaming_call, sentence)
        except Exception as exc:
            # Without this, the consumer below would wait on audio_queue forever.
            audio_queue.put_nowait(exc)

    # Keep a reference so the feeder task cannot be garbage-collected mid-run.
    feeder = asyncio.create_task(_feed_text())
    finished = False
    try:
        while True:
            item = await audio_queue.get()
            if item is None:
                finished = True
                break
            if isinstance(item, Exception):
                raise item
            yield item
    finally:
        # On error or early exit by the consumer, stop waiting for more text.
        if not finished and not feeder.done():
            feeder.cancel()
|
||
|
||
@staticmethod
async def _tts_openai_stream(
    api_key: str,
    api_base: Optional[str],
    text: str,
    voice: str,
):
    """OpenAI-compatible streaming TTS for a complete text; yields audio byte chunks.

    Uses model ``tts-1``; input is truncated to 4096 characters.
    """
    from openai import AsyncOpenAI

    # Close the client's HTTP connection pool when done instead of leaking it.
    async with AsyncOpenAI(api_key=api_key, base_url=api_base) as client:
        async with client.audio.speech.with_streaming_response.create(
            model="tts-1",
            voice=voice,
            input=text[:4096],
        ) as response:
            async for chunk in response.iter_bytes(chunk_size=4096):
                yield chunk
|
||
|
||
@staticmethod
async def _tts_dashscope_stream(
    api_key: str,
    text: str,
    voice: str,
    tts_config: Dict[str, Any],
):
    """DashScope streaming TTS for a complete text; yields audio byte chunks.

    Input is truncated to 4096 characters. The blocking SDK calls run in a
    worker thread; audio from the SDK callbacks is handed to the event loop
    through a queue.

    Raises:
        RuntimeError: when the SDK reports an error through ``on_error``.
        Exception: any error raised by the blocking SDK calls themselves.
    """
    import dashscope
    from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback

    model = tts_config.get("model") or "cosyvoice-v2"
    # Align the voice name's "_v2" suffix with the model generation.
    is_v2 = model.endswith("-v2")
    if is_v2 and not voice.endswith("_v2"):
        voice = voice + "_v2"
    elif not is_v2 and voice.endswith("_v2"):
        voice = voice[:-3]

    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    class _Callback(ResultCallback):
        # Results are marshalled onto the event loop with call_soon_threadsafe
        # because the SDK runs in the worker thread.
        def on_data(self, data: bytes):
            if data:
                loop.call_soon_threadsafe(queue.put_nowait, data)

        def on_complete(self):
            loop.call_soon_threadsafe(queue.put_nowait, None)

        def on_error(self, message):
            loop.call_soon_threadsafe(queue.put_nowait, RuntimeError(str(message)))

        def on_open(self):
            pass

        def on_close(self):
            pass

    def _sync_stream():
        # NOTE(review): dashscope.api_key is module-global; concurrent requests
        # with different keys may race here.
        dashscope.api_key = api_key
        synthesizer = SpeechSynthesizer(
            model=model,
            voice=voice,
            format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
            callback=_Callback(),
        )
        synthesizer.streaming_call(text[:4096])
        synthesizer.streaming_complete()

    async def _run_sync():
        try:
            await asyncio.to_thread(_sync_stream)
        except Exception as exc:
            # Errors raised outside the SDK callbacks (e.g. while connecting)
            # would otherwise leave the loop below waiting forever.
            queue.put_nowait(exc)

    # Held in a local so the task is not garbage-collected while we consume.
    worker = asyncio.create_task(_run_sync())
    while True:
        item = await queue.get()
        if item is None:
            break
        if isinstance(item, Exception):
            raise item
        yield item
    del worker
|
||
|
||
def _replace_variables(
    self,
    text: str,
    values: Dict[str, Any],
    definitions: List[Dict[str, Any]]
) -> str:
    """Substitute declared variables into text.

    Supported placeholder forms are ``{{name}}``, ``${name}`` and ``{name}``.
    Each value is inserted as ``str(value)``. Values whose name is not declared
    in ``definitions`` are skipped with a warning.

    Args:
        text: template text.
        values: variable values keyed by variable name.
        definitions: variable definitions; each must have a ``"name"`` key.

    Returns:
        str: text with all declared placeholders replaced.
    """
    if not values:
        return text

    declared = {var["name"] for var in definitions or []}
    result = text

    for var_name, var_value in values.items():
        if var_name not in declared:
            logger.warning(f"未定义的变量: {var_name}")
            continue

        replacement = str(var_value)
        # Order matters: "{name}" is a substring of both "{{name}}" and
        # "${name}", so the longer forms must be replaced first. Replacing
        # "{name}" before "${name}" used to turn "${name}" into "$value".
        for placeholder in (
            f"{{{{{var_name}}}}}",  # {{var_name}}
            f"${{{var_name}}}",     # ${var_name}
            f"{{{var_name}}}",      # {var_name}
        ):
            result = result.replace(placeholder, replacement)

    return result
|
||
|
||
# ==================== 多模型对比试运行 ====================
|
||
|
||
async def run_compare(
    self,
    *,
    agent_config: AgentConfig,
    models: List[Dict[str, Any]],
    message: str,
    workspace_id: uuid.UUID,
    conversation_id: Optional[str] = None,
    user_id: Optional[str] = None,
    variables: Optional[Dict[str, Any]] = None,
    parallel: bool = True,
    timeout: int = 60,
    storage_type: Optional[str] = None,
    user_rag_memory_id: Optional[str] = None,
    web_search: bool = True,
    memory: bool = True,
    files: list[FileInput] | None = None
) -> Dict[str, Any]:
    """Run the same message against several models and compare the results.

    Each entry of ``models`` is run through ``self.run`` with its own
    parameters (temporarily set on ``agent_config.model_parameters``) and its
    own ``conversation_id`` if given, falling back to ``conversation_id``.

    Args:
        agent_config: agent configuration shared by all runs.
        models: entries with ``model_config``, ``parameters``, ``label``,
            ``model_config_id`` and optionally ``conversation_id``.
        message: user message.
        workspace_id: workspace ID.
        conversation_id: default conversation ID.
        user_id: user ID.
        variables: prompt variables.
        parallel: run models concurrently (asyncio.gather) instead of one by one.
        timeout: per-model timeout in seconds.
        storage_type: forwarded to ``self.run``.
        user_rag_memory_id: forwarded to ``self.run``.
        web_search: forwarded to ``self.run``.
        memory: forwarded to ``self.run``.
        files: forwarded to ``self.run`` unchanged (not validated here).

    Returns:
        Dict: ``results`` (one entry per model in input order; failed entries
        carry a non-empty ``error``), ``total_elapsed_time`` (sum of per-model
        times), ``successful_count``, ``failed_count``, and the labels
        ``fastest_model`` / ``cheapest_model`` (or None).
    """
    logger.info(
        "多模型对比试运行",
        extra={
            "model_count": len(models),
            "parallel": parallel
        }
    )

    # Snapshot once, before any run starts. With parallel execution a per-run
    # snapshot could capture another model's temporary parameters and "restore"
    # them onto the session-attached agent_config.
    original_params = agent_config.model_parameters

    def _failure(model_info, model_conversation_id, elapsed, error):
        """Result entry for a run that timed out or raised."""
        return {
            "model_config_id": model_info["model_config_id"],
            "model_name": model_info["model_config"].name,
            "label": model_info["label"],
            "conversation_id": model_conversation_id,
            "parameters_used": model_info["parameters"],
            "elapsed_time": elapsed,
            "error": error
        }

    async def run_single_model(model_info):
        """Run one model; never raises, failures become result entries."""
        start_time = time.time()
        model_conversation_id = model_info.get("conversation_id") or conversation_id
        try:
            # NOTE(review): parallel runs share agent_config without copying it
            # (deepcopy is avoided because of the SQLAlchemy session). If run()
            # reads model_parameters after its first await, it may see another
            # model's parameters; passing them to run() explicitly would remove
            # this race.
            agent_config.model_parameters = model_info["parameters"]
            result = await asyncio.wait_for(
                self.run(
                    agent_config=agent_config,
                    model_config=model_info["model_config"],
                    message=message,
                    workspace_id=workspace_id,
                    conversation_id=model_conversation_id,
                    user_id=user_id,
                    variables=variables,
                    storage_type=storage_type,
                    user_rag_memory_id=user_rag_memory_id,
                    web_search=web_search,
                    memory=memory,
                    files=files
                ),
                timeout=timeout
            )

            elapsed = time.time() - start_time
            usage = result.get("usage") or {}

            return {
                "model_config_id": model_info["model_config_id"],
                "model_name": model_info["model_config"].name,
                "label": model_info["label"],
                "conversation_id": result['conversation_id'],
                "parameters_used": model_info["parameters"],
                "message": result.get("message"),
                "reasoning_content": result.get("reasoning_content"),
                "usage": usage,
                "elapsed_time": elapsed,
                "tokens_per_second": (
                    usage.get("completion_tokens", 0) / elapsed
                    if elapsed > 0 and usage.get("completion_tokens") else None
                ),
                "cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
                "audio_url": result.get("audio_url"),
                "audio_status": result.get("audio_status"),
                "citations": result.get("citations", []),
                "suggested_questions": result.get("suggested_questions", []),
                "error": None
            }

        # asyncio.TimeoutError is an alias of TimeoutError on 3.11+, but a
        # distinct class on 3.10, where catching only TimeoutError misses it.
        except asyncio.TimeoutError:
            logger.warning(
                "模型运行超时",
                extra={
                    "model_config_id": str(model_info["model_config_id"]),
                    "timeout": timeout
                }
            )
            return _failure(model_info, model_conversation_id, timeout, f"执行超时({timeout}秒)")
        except Exception as e:
            logger.error(
                "模型运行失败",
                extra={
                    "model_config_id": str(model_info["model_config_id"]),
                    "error": str(e)
                },
                exc_info=True
            )
            # An empty message would make the entry count as successful.
            return _failure(model_info, model_conversation_id, 0, str(e) or type(e).__name__)

    try:
        if parallel:
            logger.debug(f"并行执行 {len(models)} 个模型")
            results = list(await asyncio.gather(
                *[run_single_model(m) for m in models],
                return_exceptions=False
            ))
        else:
            logger.debug(f"串行执行 {len(models)} 个模型")
            results = []
            for model_info in models:
                results.append(await run_single_model(model_info))
    finally:
        agent_config.model_parameters = original_params

    successful = [r for r in results if not r.get("error")]
    failed = [r for r in results if r.get("error")]

    fastest = min(successful, key=lambda r: r["elapsed_time"], default=None)
    # A missing estimate sorts last, but a genuine 0.0 cost still counts.
    cheapest = min(
        successful,
        key=lambda r: float("inf") if r.get("cost_estimate") is None else r["cost_estimate"],
        default=None
    )
    total_elapsed = sum(r.get("elapsed_time", 0) for r in results)

    logger.info(
        "多模型对比完成",
        extra={
            "successful": len(successful),
            "failed": len(failed),
            "total_time": total_elapsed
        }
    )

    return {
        "results": [{
            **r,
            "audio_url": r.get("audio_url"),
            "audio_status": r.get("audio_status"),
            "citations": r.get("citations", []),
            "suggested_questions": r.get("suggested_questions", []),
        } for r in results],
        "total_elapsed_time": total_elapsed,
        "successful_count": len(successful),
        "failed_count": len(failed),
        "fastest_model": fastest["label"] if fastest else None,
        "cheapest_model": cheapest["label"] if cheapest else None
    }
|
||
|
||
def _estimate_cost(self, usage: Dict[str, Any], model_config) -> Optional[float]:
|
||
"""估算成本
|
||
|
||
Args:
|
||
usage: Token 使用情况
|
||
model_config: 模型配置
|
||
|
||
Returns:
|
||
Optional[float]: 估算成本(美元)
|
||
"""
|
||
if not usage:
|
||
return None
|
||
|
||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||
completion_tokens = usage.get("completion_tokens", 0)
|
||
|
||
# 简化成本估算:暂时返回 None
|
||
# TODO: 实现基于模型名称或配置的成本估算
|
||
# 需要从 ModelApiKey 获取实际的模型名称,或者在 ModelConfig 中添加 model 字段
|
||
return None
|
||
|
||
def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]) -> tuple[AgentConfig, Any]:
|
||
"""创建一个带有覆盖参数的 agent_config(浅拷贝,只修改 model_parameters)
|
||
|
||
Args:
|
||
agent_config: 原始 Agent 配置
|
||
parameters: 要覆盖的参数
|
||
|
||
Returns:
|
||
AgentConfig: 修改后的配置(注意:这是同一个对象,只是临时修改了 model_parameters)
|
||
"""
|
||
# 保存原始参数
|
||
original_params = agent_config.model_parameters
|
||
# 设置新参数
|
||
agent_config.model_parameters = parameters
|
||
return agent_config, original_params
|
||
|
||
async def run_compare_stream(
    self,
    *,
    agent_config: AgentConfig,
    models: List[Dict[str, Any]],
    message: str,
    workspace_id: uuid.UUID,
    conversation_id: Optional[str] = None,
    user_id: Optional[str] = None,
    variables: Optional[Dict[str, Any]] = None,
    storage_type: Optional[str] = None,
    user_rag_memory_id: Optional[str] = None,
    web_search: bool = True,
    memory: bool = True,
    parallel: bool = True,
    timeout: int = 60,
    files: list[FileInput] | None = None
) -> AsyncGenerator[str, None]:
    """Multi-model comparison trial run, streamed as SSE events.

    Streaming counterpart of ``run_compare``. Every entry of ``models`` is run
    through ``self.run_stream`` with the same agent configuration and input.
    Each upstream event is re-emitted tagged with the model's index, config id
    and label. Models run concurrently (``parallel=True``) or one after
    another; in both modes events are forwarded as soon as they are produced.

    Emitted events: ``compare_start`` once; per model ``model_start``, then
    ``model_message`` / ``model_reasoning`` chunks, ``model_error`` when the
    run reports or raises an error, and ``model_end`` unless the run raised or
    timed out; finally ``compare_end`` with the per-model summary (in model
    order).

    Args:
        agent_config: Agent configuration shared by all runs. Its
            ``model_parameters`` attribute is overwritten with each model's
            ``parameters`` while that model starts, and the original value is
            restored once the comparison finishes or the consumer stops
            iterating.
        models: Model entries; each dict carries ``model_config``,
            ``parameters``, ``label``, ``model_config_id`` and optionally its
            own ``conversation_id``.
        message: User message sent to every model.
        workspace_id: Workspace ID.
        conversation_id: Fallback conversation ID for models without their own.
        user_id: User ID.
        variables: Prompt variables.
        storage_type: Storage type.
        user_rag_memory_id: RAG memory ID.
        web_search: Whether web search is enabled.
        memory: Whether memory is enabled.
        parallel: Run models concurrently when True, sequentially otherwise.
        timeout: Per-model time limit in seconds; a model exceeding it is
            reported with a timeout error.
        files: Multimodal file inputs.

    Yields:
        str: SSE-formatted event strings.
    """
    logger.info(
        "多模型对比流式试运行",
        extra={"model_count": len(models), "parallel": parallel}
    )

    # NOTE: upfront file-upload validation (self._validate_file_upload against
    # agent_config.features) is currently disabled.

    yield self._format_sse_event("compare_start", {
        "conversation_id": conversation_id,
        "model_count": len(models),
        "parallel": parallel,
        "timestamp": time.time()
    })

    def parse_sse_event(event_str: str) -> tuple[Optional[str], Any]:
        """Split one SSE frame produced by run_stream into (event name, decoded JSON data)."""
        event_type = None
        event_data = None
        for line in event_str.strip().split('\n'):
            if line.startswith('event: '):
                event_type = line[7:].strip()
            elif line.startswith('data: '):
                event_data = json.loads(line[6:])
        return event_type, event_data

    async def run_single_model_stream(idx: int, model_info: Dict[str, Any], event_queue: asyncio.Queue):
        """Run one model, push its tagged SSE events onto event_queue and return its result dict."""
        model_label = model_info["label"]
        model_config_id = str(model_info["model_config_id"])
        # Prefer the model's own conversation; fall back to the shared one.
        model_conversation_id = model_info.get("conversation_id") or conversation_id

        # Per-model state, filled in by consume_stream() as events arrive.
        returned_conversation_id = model_conversation_id
        full_content = ""
        full_reasoning = ""
        audio_url = None
        audio_status = None
        citations = []
        suggested_questions = []
        stream_error = None

        async def consume_stream() -> None:
            """Drain run_stream for this model, forwarding chunks and collecting end-of-run fields."""
            nonlocal returned_conversation_id, full_content, full_reasoning
            nonlocal audio_url, audio_status, citations, suggested_questions, stream_error

            stream = self.run_stream(
                agent_config=agent_config,
                model_config=model_info["model_config"],
                message=message,
                workspace_id=workspace_id,
                conversation_id=model_conversation_id,
                user_id=user_id,
                variables=variables,
                storage_type=storage_type,
                user_rag_memory_id=user_rag_memory_id,
                web_search=web_search,
                memory=memory,
                files=files
            )
            try:
                async for event_str in stream:
                    try:
                        event_type, event_data = parse_sse_event(event_str)
                        if not event_data:
                            continue

                        if event_type == "start":
                            # The actual conversation id is only known once run_stream starts.
                            conv_id = event_data.get("conversation_id")
                            if conv_id:
                                returned_conversation_id = conv_id

                        elif event_type == "message":
                            chunk = event_data.get("content", "")
                            full_content += chunk
                            await event_queue.put(self._format_sse_event("model_message", {
                                "model_index": idx,
                                "model_config_id": model_config_id,
                                "label": model_label,
                                "conversation_id": returned_conversation_id,
                                "content": chunk
                            }))

                        elif event_type == "reasoning":
                            reasoning_chunk = event_data.get("content", "")
                            full_reasoning += reasoning_chunk
                            await event_queue.put(self._format_sse_event("model_reasoning", {
                                "model_index": idx,
                                "model_config_id": model_config_id,
                                "label": model_label,
                                "conversation_id": returned_conversation_id,
                                "content": reasoning_chunk
                            }))

                        elif event_type == "end":
                            # Feature outputs are only reported on the end event.
                            audio_url = event_data.get("audio_url")
                            audio_status = event_data.get("audio_status")
                            citations = event_data.get("citations", [])
                            suggested_questions = event_data.get("suggested_questions", [])

                        elif event_type == "error":
                            # Remember the error so this run is counted as failed.
                            stream_error = event_data.get("error", "未知错误")
                            await event_queue.put(self._format_sse_event("model_error", {
                                "model_index": idx,
                                "model_config_id": model_config_id,
                                "label": model_label,
                                "conversation_id": returned_conversation_id,
                                "error": stream_error
                            }))
                    except Exception as e:
                        logger.warning(f"解析流式事件失败: {e}")
            finally:
                # Close the upstream generator promptly, also when cancelled by the timeout.
                aclose = getattr(stream, "aclose", None)
                if aclose is not None:
                    await aclose()

        try:
            await event_queue.put(self._format_sse_event("model_start", {
                "model_index": idx,
                "model_config_id": model_config_id,
                "model_name": model_info["model_config"].name,
                "label": model_label,
                "conversation_id": model_conversation_id,
                "timestamp": time.time()
            }))

            start_time = time.time()
            # agent_config is shared by all runs; the original parameters are
            # restored once by the caller after every model has finished.
            agent_config.model_parameters = model_info["parameters"]

            await asyncio.wait_for(consume_stream(), timeout=timeout)

            elapsed = time.time() - start_time

            result = {
                "model_config_id": model_info["model_config_id"],
                "model_name": model_info["model_config"].name,
                "label": model_label,
                "conversation_id": returned_conversation_id,
                "parameters_used": model_info["parameters"],
                "message": full_content,
                "reasoning_content": full_reasoning or None,
                "elapsed_time": elapsed,
                "audio_url": audio_url,
                "audio_status": audio_status,
                "citations": citations,
                "suggested_questions": suggested_questions,
                "error": stream_error
            }

            await event_queue.put(self._format_sse_event("model_end", {
                "model_index": idx,
                "model_config_id": model_config_id,
                "label": model_label,
                "conversation_id": returned_conversation_id,
                "elapsed_time": elapsed,
                "message_length": len(full_content),
                "audio_url": audio_url,
                "audio_status": audio_status,
                "citations": citations,
                "suggested_questions": suggested_questions,
                "timestamp": time.time()
            }))

            return result

        # asyncio.TimeoutError is only an alias of the builtin from Python 3.11 on.
        except (asyncio.TimeoutError, TimeoutError):
            logger.warning(f"模型运行超时: {model_label}")
            result = {
                "model_config_id": model_info["model_config_id"],
                "model_name": model_info["model_config"].name,
                "label": model_label,
                "conversation_id": model_conversation_id,
                "parameters_used": model_info["parameters"],
                "elapsed_time": timeout,
                "error": f"执行超时({timeout}秒)"
            }

            await event_queue.put(self._format_sse_event("model_error", {
                "model_index": idx,
                "model_config_id": model_config_id,
                "label": model_label,
                "conversation_id": model_conversation_id,
                "error": result["error"],
                "timestamp": time.time()
            }))

            return result

        except Exception as e:
            logger.error(f"模型运行失败: {model_label}, error: {e}")
            result = {
                "model_config_id": model_info["model_config_id"],
                "model_name": model_info["model_config"].name,
                "label": model_label,
                "conversation_id": model_conversation_id,
                "parameters_used": model_info["parameters"],
                "elapsed_time": 0,
                "error": str(e)
            }

            await event_queue.put(self._format_sse_event("model_error", {
                "model_index": idx,
                "model_config_id": model_config_id,
                "label": model_label,
                "conversation_id": model_conversation_id,
                "error": str(e),
                "timestamp": time.time()
            }))

            return result

    event_queue: asyncio.Queue = asyncio.Queue()
    # Sentinel enqueued by the producer once no more events will follow.
    all_done = object()

    async def produce() -> List[Dict[str, Any]]:
        """Run every model (concurrently or sequentially) and return their results in model order."""
        try:
            if parallel:
                logger.debug(f"并行执行 {len(models)} 个模型(流式)")
                outcomes = await asyncio.gather(
                    *(run_single_model_stream(idx, model_info, event_queue)
                      for idx, model_info in enumerate(models)),
                    return_exceptions=True
                )
                collected = []
                for outcome in outcomes:
                    if isinstance(outcome, BaseException):
                        # One broken model must not abort the whole comparison.
                        logger.error(f"获取任务结果失败: {outcome}")
                    elif outcome:
                        collected.append(outcome)
                return collected

            logger.debug(f"串行执行 {len(models)} 个模型(流式)")
            collected = []
            for idx, model_info in enumerate(models):
                result = await run_single_model_stream(idx, model_info, event_queue)
                if result:
                    collected.append(result)
            return collected
        finally:
            # Always wake the consumer, even on error or cancellation.
            event_queue.put_nowait(all_done)

    original_params = agent_config.model_parameters
    producer = asyncio.create_task(produce())
    try:
        # Forward events as they arrive, in both parallel and serial mode.
        while (event := await event_queue.get()) is not all_done:
            yield event
        results = await producer
    finally:
        # Runs on completion and when the client stops consuming (generator
        # closed): stop models still running and restore the caller's
        # parameters exactly once.
        if not producer.done():
            producer.cancel()
        agent_config.model_parameters = original_params

    # Summary statistics (same rules as run_compare).
    successful = [r for r in results if not r.get("error")]
    failed = [r for r in results if r.get("error")]

    fastest = min(successful, key=lambda x: x["elapsed_time"]) if successful else None
    cheapest = min(
        successful,
        key=lambda x: x.get("cost_estimate") or float("inf")
    ) if successful else None

    # Per-model summary including the full message.
    results_summary = [
        {
            "model_config_id": str(r["model_config_id"]),
            "model_name": r["model_name"],
            "label": r["label"],
            "conversation_id": r.get("conversation_id"),
            "message": r.get("message"),
            "reasoning_content": r.get("reasoning_content"),
            "elapsed_time": r.get("elapsed_time", 0),
            "audio_url": r.get("audio_url"),
            "audio_status": r.get("audio_status"),
            "citations": r.get("citations", []),
            "suggested_questions": r.get("suggested_questions", []),
            "error": r.get("error")
        }
        for r in results
    ]

    total_elapsed = sum(r.get("elapsed_time", 0) for r in results)

    # Final event, mirroring run_compare's return shape.
    yield self._format_sse_event("compare_end", {
        "conversation_id": conversation_id,
        "results": results_summary,
        "total_elapsed_time": total_elapsed,
        "successful_count": len(successful),
        "failed_count": len(failed),
        "fastest_model": fastest["label"] if fastest else None,
        "cheapest_model": cheapest["label"] if cheapest else None,
        "timestamp": time.time()
    })

    logger.info(
        "多模型对比流式完成",
        extra={
            "successful": len(successful),
            "failed": len(failed),
            "total_time": total_elapsed
        }
    )
|