feat(app):

1. Add new functional features to the agent;
2. Enhance the voice output;
3. Modify the end_user binding;
4. Delete and modify the tools.
This commit is contained in:
Timebomb2018
2026-03-16 18:00:09 +08:00
parent b62c40dba3
commit ea391dc44e
22 changed files with 832 additions and 184 deletions

View File

@@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
@@ -262,9 +263,12 @@ class AgentRunService:
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
"""加载工具配置"""
if not tools_config:
return []
tools = []
if web_search:
search_tool = create_web_search_tool({})
tools.append(search_tool)
if not tools_config:
return tools
tool_service = ToolService(self.db)
if tools_config and isinstance(tools_config, list):
@@ -273,24 +277,15 @@ class AgentRunService:
# 根据工具名称查找工具实例
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif tools_config and isinstance(tools_config, dict):
web_search_choice = tools_config.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search and web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
return tools
def load_skill_config(
@@ -373,6 +368,86 @@ class AgentRunService:
)
return tools, bool(memory_config.get("enabled"))
@staticmethod
def _validate_file_upload(
features_config: Dict[str, Any],
files: Optional[List[FileInput]]
) -> None:
"""校验上传文件是否符合 file_upload 配置"""
if not files:
return
fu = features_config.get("file_upload", {})
if not (isinstance(fu, dict) and fu.get("enabled")):
raise BusinessException("该应用未开启文件上传功能", BizCode.BAD_REQUEST)
max_count = fu.get("max_file_count", 5)
if len(files) > max_count:
raise BusinessException(f"文件数量超过限制(最多 {max_count} 个)", BizCode.BAD_REQUEST)
# 校验传输方式
allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"])
for f in files:
if str(f.transfer_method) not in allowed_methods:
raise BusinessException(
f"不支持的文件传输方式:{f.transfer_method},允许的方式:{', '.join(allowed_methods)}",
BizCode.BAD_REQUEST
)
# 各类型对应的开关和大小限制配置键
type_cfg = {
"image": ("image_enabled", "image_max_size_mb", 20, "图片"),
"audio": ("audio_enabled", "audio_max_size_mb", 50, "音频"),
"document": ("document_enabled", "document_max_size_mb", 100, "文档"),
"video": ("video_enabled", "video_max_size_mb", 500, "视频"),
}
for f in files:
ftype = str(f.type) # 如 "image", "audio", "document", "video"
cfg = type_cfg.get(ftype)
if cfg is None:
continue
enabled_key, size_key, default_max_mb, label = cfg
# 校验类型开关
if not fu.get(enabled_key):
raise BusinessException(f"该应用未开启{label}文件上传", BizCode.BAD_REQUEST)
# 校验文件大小(仅当内容已加载时)
content = f.get_content()
if content is not None:
max_mb = fu.get(size_key, default_max_mb)
size_mb = len(content) / (1024 * 1024)
if size_mb > max_mb:
raise BusinessException(
f"{label}文件大小超过限制(最大 {max_mb}MB当前 {size_mb:.1f}MB",
BizCode.BAD_REQUEST
)
@staticmethod
def _inject_opening_statement(
features_config: Dict[str, Any],
system_prompt: str,
is_new_conversation: bool
) -> str:
"""首轮对话时将开场白注入 system_prompt"""
if not is_new_conversation:
return system_prompt
opening = features_config.get("opening_statement", {})
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
return system_prompt
statement = opening["statement"]
return f"{system_prompt}\n\n[对话开场白]\n{statement}"
@staticmethod
def _filter_citations(
features_config: Dict[str, Any],
citations: List[Any]
) -> List[Any]:
"""根据 citation 开关决定是否返回引用来源"""
citation_cfg = features_config.get("citation", {})
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
return citations
return []
async def run(
self,
*,
@@ -415,6 +490,15 @@ class AgentRunService:
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory
features_config: dict = agent_config.features or {}
# 从 features 中读取功能开关(优先级高于参数默认值)
web_search_feature = features_config.get("web_search", {})
if not isinstance(web_search_feature, dict) or not web_search_feature.get("enabled"):
web_search = False
# file_upload 校验
self._validate_file_upload(features_config, files)
try:
# 1. 获取 API Key 配置
@@ -449,6 +533,10 @@ class AgentRunService:
# 3. 处理系统提示词(支持变量替换)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
# opening_statement首轮对话注入开场白
is_new_conversation = not conversation_id
system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation)
# 4. 准备工具列表
tools = []
@@ -491,12 +579,10 @@ class AgentRunService:
)
# 6. 加载历史消息
history = []
if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=agent_config.memory.get("max_history", 10)
)
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=10
)
# 6. 处理多模态文件
processed_files = None
@@ -551,7 +637,7 @@ class AgentRunService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
# 9. 保存会话消息
if not sub_agent and memory_config and memory_config.get("enabled"):
if not sub_agent:
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,
@@ -575,7 +661,15 @@ class AgentRunService:
"completion_tokens": 0,
"total_tokens": 0
}),
"elapsed_time": elapsed_time
"elapsed_time": elapsed_time,
"suggested_questions": await self._generate_suggested_questions(
features_config, result["content"], api_key_config, effective_params
) if not sub_agent else [],
"citations": self._filter_citations(features_config, result.get("citations", [])),
"audio_url": await self._generate_tts(
features_config, result["content"], api_key_config,
tenant_id=tenant_id, workspace_id=workspace_id
) if not sub_agent else None,
}
logger.info(
@@ -630,6 +724,15 @@ class AgentRunService:
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory
features_config: dict = agent_config.features or {}
# 从 features 中读取功能开关
web_search_feature = features_config.get("web_search", {})
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
web_search = False
# file_upload 校验
self._validate_file_upload(features_config, files)
start_time = time.time()
@@ -659,6 +762,10 @@ class AgentRunService:
# 3. 处理系统提示词(支持变量替换)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
# opening_statement首轮对话注入开场白
is_new_conversation = not conversation_id
system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation)
# 4. 准备工具列表
tools = []
@@ -703,12 +810,10 @@ class AgentRunService:
)
# 6. 加载历史消息
history = []
if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=memory_config.get("max_history", 10)
)
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=memory_config.get("max_history", 10)
)
# 6. 处理多模态文件
processed_files = None
@@ -774,7 +879,7 @@ class AgentRunService:
})
# 10. 保存会话消息
if not sub_agent and memory_config and memory_config.get("enabled"):
if not sub_agent:
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,
@@ -786,12 +891,22 @@ class AgentRunService:
}
)
# 11. 发送结束事件
yield self._format_sse_event("end", {
# 11. 发送结束事件(包含 suggested_questions 和 tts
end_data: Dict[str, Any] = {
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
"message_length": len(full_content)
})
}
if not sub_agent:
end_data["suggested_questions"] = await self._generate_suggested_questions(
features_config, full_content, api_key_config, effective_params
)
end_data["audio_url"] = await self._generate_tts(
features_config, full_content, api_key_config,
tenant_id=tenant_id, workspace_id=workspace_id
)
end_data["citations"] = self._filter_citations(features_config, [])
yield self._format_sse_event("end", end_data)
logger.info(
"流式试运行完成",
@@ -1137,6 +1252,165 @@ class AgentRunService:
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
return {}
async def _generate_suggested_questions(
self,
features_config: Dict[str, Any],
assistant_message: str,
api_key_config: Dict[str, Any],
effective_params: Dict[str, Any]
) -> List[str]:
"""根据 suggested_questions_after_answer 配置生成下一步建议问题"""
sq_config = features_config.get("suggested_questions_after_answer", {})
if not isinstance(sq_config, dict) or not sq_config.get("enabled"):
return []
try:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
llm = ChatOpenAI(
model=api_key_config["model_name"],
api_key=api_key_config["api_key"],
base_url=api_key_config.get("api_base"),
temperature=0.5,
max_tokens=200,
)
prompt = (
f"根据以下AI回复生成3个用户可能继续追问的简短问题每行一个不加序号\n\n{assistant_message}"
)
resp = await llm.ainvoke([HumanMessage(content=prompt)])
lines = [l.strip() for l in resp.content.strip().split("\n") if l.strip()]
return lines[:3]
except Exception as e:
logger.warning(f"生成建议问题失败: {e}")
return []
async def _generate_tts(
self,
features_config: Dict[str, Any],
text: str,
api_key_config: Dict[str, Any],
tenant_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
) -> Optional[str]:
"""根据 text_to_speech 配置生成语音,上传到存储并返回 URL"""
tts_config = features_config.get("text_to_speech", {})
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
return None
if not text or not text.strip():
return None
try:
from app.services.file_storage_service import FileStorageService
provider = api_key_config.get("provider", "openai")
api_key = api_key_config.get("api_key")
api_base = api_key_config.get("api_base")
voice = tts_config.get("voice")
if provider == "dashscope":
audio_bytes, file_ext, content_type = await self._tts_dashscope(
api_key=api_key,
text=text,
voice=voice or "longxiaochun", # 会根据 model 版本自动修正后缀
tts_config=tts_config,
)
else:
# OpenAI 兼容接口openai / xinference / gpustack 等)
audio_bytes, file_ext, content_type = await self._tts_openai(
api_key=api_key,
api_base=api_base,
text=text,
voice=voice or "alloy",
)
storage_service = FileStorageService()
file_id = uuid.uuid4()
file_key = await storage_service.upload_file(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
content=audio_bytes,
content_type=content_type,
)
# 保存文件元数据到数据库
from app.models.file_metadata_model import FileMetadata
db_file = FileMetadata(
id=file_id,
tenant_id=tenant_id,
workspace_id=workspace_id,
file_key=file_key,
file_name=f"tts_{file_id}{file_ext}",
file_ext=file_ext,
file_size=len(audio_bytes),
content_type=content_type,
status="completed",
)
self.db.add(db_file)
self.db.commit()
server_url = settings.FILE_LOCAL_SERVER_URL
audio_url = f"{server_url}/storage/permanent/{file_id}"
logger.debug(f"TTS 生成成功provider={provider}, file_key={file_key}")
return audio_url
except Exception as e:
logger.warning(f"TTS 生成失败: {e}")
return None
@staticmethod
async def _tts_openai(
api_key: str,
api_base: Optional[str],
text: str,
voice: str,
) -> tuple:
"""OpenAI 兼容 TTS返回 (audio_bytes, file_ext, content_type)"""
from openai import AsyncOpenAI
client = AsyncOpenAI(api_key=api_key, base_url=api_base)
response = await client.audio.speech.create(
model="tts-1",
voice=voice,
input=text[:4096],
)
return response.content, ".mp3", "audio/mpeg"
@staticmethod
async def _tts_dashscope(
api_key: str,
text: str,
voice: str,
tts_config: Dict[str, Any],
) -> tuple:
"""DashScope CosyVoice TTS返回 (audio_bytes, file_ext, content_type)"""
import dashscope
from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat
model = tts_config.get("model") or "cosyvoice-v2"
is_v2 = model.endswith("-v2")
# cosyvoice-v2 音色名带 _v2 后缀v1 不带
# 如果用户传入的 voice 不匹配当前模型版本,自动修正
if is_v2 and not voice.endswith("_v2"):
voice = voice + "_v2"
elif not is_v2 and voice.endswith("_v2"):
voice = voice[:-3] # 去掉 _v2
def _sync_call() -> bytes:
dashscope.api_key = api_key
synthesizer = SpeechSynthesizer(
model=model,
voice=voice,
format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
)
audio = synthesizer.call(text[:4096])
if audio is None:
raise RuntimeError("DashScope TTS 返回空音频")
return audio
audio_bytes = await asyncio.to_thread(_sync_call)
return audio_bytes, ".mp3", "audio/mpeg"
def _replace_variables(
self,
text: str,
@@ -1221,6 +1495,12 @@ class AgentRunService:
}
)
# 提前校验文件上传(与 run() 内部保持一致)
features_config: dict = agent_config.features or {}
if hasattr(features_config, 'model_dump'):
features_config = features_config.model_dump()
# self._validate_file_upload(features_config, files)
async def run_single_model(model_info):
"""运行单个模型"""
try:
@@ -1271,6 +1551,9 @@ class AgentRunService:
if elapsed > 0 and usage.get("completion_tokens") else None
),
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
"audio_url": result.get("audio_url"),
"citations": result.get("citations", []),
"suggested_questions": result.get("suggested_questions", []),
"error": None
}
@@ -1343,7 +1626,12 @@ class AgentRunService:
)
return {
"results": results,
"results": [{
**r,
"audio_url": r.get("audio_url"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
} for r in results],
"total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results),
"successful_count": len(successful),
"failed_count": len(failed),
@@ -1434,6 +1722,12 @@ class AgentRunService:
extra={"model_count": len(models), "parallel": parallel}
)
# 提前校验文件上传
# features_config: dict = agent_config.features or {}
# if hasattr(features_config, 'model_dump'):
# features_config = features_config.model_dump()
# self._validate_file_upload(features_config, files)
# 发送开始事件
yield self._format_sse_event("compare_start", {
"conversation_id": conversation_id,
@@ -1465,6 +1759,9 @@ class AgentRunService:
start_time = time.time()
full_content = ""
returned_conversation_id = model_conversation_id
audio_url = None
citations = []
suggested_questions = []
# 临时修改参数
original_params = agent_config.model_parameters
@@ -1518,6 +1815,12 @@ class AgentRunService:
"content": chunk
}))
# 从 end 事件中提取 features 输出字段
if event_type == "end" and event_data:
audio_url = event_data.get("audio_url")
citations = event_data.get("citations", [])
suggested_questions = event_data.get("suggested_questions", [])
if event_type == "error" and event_data:
await event_queue.put(self._format_sse_event("model_error", {
"model_index": idx,
@@ -1543,6 +1846,9 @@ class AgentRunService:
"parameters_used": model_info["parameters"],
"message": full_content,
"elapsed_time": elapsed,
"audio_url": audio_url,
"citations": citations,
"suggested_questions": suggested_questions,
"error": None
}
@@ -1554,6 +1860,9 @@ class AgentRunService:
"conversation_id": returned_conversation_id,
"elapsed_time": elapsed,
"message_length": len(full_content),
"audio_url": audio_url,
"citations": citations,
"suggested_questions": suggested_questions,
"timestamp": time.time()
}))
@@ -1685,8 +1994,11 @@ class AgentRunService:
"model_name": r["model_name"],
"label": r["label"],
"conversation_id": r.get("conversation_id"),
"message": r.get("message"), # 包含完整消息
"message": r.get("message"),
"elapsed_time": r.get("elapsed_time", 0),
"audio_url": r.get("audio_url"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
"error": r.get("error")
})