feat(app):

1. Add new functional features to the agent;
2. Enhance the voice output;
3. Modify the end_user binding;
4. Delete and modify the tools.
This commit is contained in:
Timebomb2018
2026-03-16 18:00:09 +08:00
parent b62c40dba3
commit ea391dc44e
22 changed files with 832 additions and 184 deletions

View File

@@ -51,6 +51,9 @@ class AgentConfigConverter:
if hasattr(config, "skills") and config.skills:
result["skills"] = config.skills.model_dump()
if hasattr(config, "features") and config.features:
result["features"] = config.features.model_dump()
return result

View File

@@ -49,12 +49,23 @@ class AppChatService:
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None,
files: Optional[List[FileInput]] = None # 新增:多模态文件
files: Optional[List[FileInput]] = None
) -> Dict[str, Any]:
"""聊天(非流式)"""
start_time = time.time()
config_id = None
# 应用 features 配置
features_config: dict = config.features or {}
if hasattr(features_config, 'model_dump'):
features_config = features_config.model_dump()
web_search_feature = features_config.get("web_search", {})
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
web_search = False
# 校验文件上传
self.agent_service._validate_file_upload(features_config, files)
variables = self.agent_service.prepare_variables(variables, config.variables)
# 获取模型配置ID
@@ -107,17 +118,14 @@ class AppChatService:
)
# 加载历史消息
history = []
memory_config = {"enabled": True, 'max_history': 10}
if memory_config.get("enabled"):
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=memory_config.get("max_history", 10)
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=10
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
# 处理多模态文件
processed_files = None
@@ -166,6 +174,23 @@ class AppChatService:
elapsed_time = time.time() - start_time
# suggested_questions
suggested_questions = []
sq_config = features_config.get("suggested_questions_after_answer", {})
if isinstance(sq_config, dict) and sq_config.get("enabled"):
suggested_questions = await self.agent_service._generate_suggested_questions(
features_config, result["content"],
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
"api_base": api_key_obj.api_base}, {}
)
audio_url = await self.agent_service._generate_tts(
features_config, result["content"],
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
"api_base": api_key_obj.api_base, "provider": api_key_obj.provider},
tenant_id=tenant_id, workspace_id=workspace_id
)
return {
"conversation_id": conversation_id,
"message_id": str(message_id),
@@ -175,7 +200,10 @@ class AppChatService:
"completion_tokens": 0,
"total_tokens": 0
}),
"elapsed_time": elapsed_time
"elapsed_time": elapsed_time,
"suggested_questions": suggested_questions,
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
"audio_url": audio_url,
}
async def agnet_chat_stream(
@@ -190,7 +218,7 @@ class AppChatService:
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None,
files: Optional[List[FileInput]] = None # 新增:多模态文件
files: Optional[List[FileInput]] = None
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
@@ -198,10 +226,19 @@ class AppChatService:
start_time = time.time()
config_id = None
message_id = uuid.uuid4()
yield f"event: start\ndata: {json.dumps({
'conversation_id': str(conversation_id),
"message_id": str(message_id)
}, ensure_ascii=False)}\n\n"
# 应用 features 配置
features_config: dict = config.features or {}
if hasattr(features_config, 'model_dump'):
features_config = features_config.model_dump()
web_search_feature = features_config.get("web_search", {})
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
web_search = False
# 校验文件上传
self.agent_service._validate_file_upload(features_config, files)
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), 'message_id': str(message_id)}, ensure_ascii=False)}\n\n"
variables = self.agent_service.prepare_variables(variables, config.variables)
# 获取模型配置ID
@@ -327,8 +364,22 @@ class AppChatService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
# 发送结束事件(包含 suggested_questions、tts、citations
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
sq_config = features_config.get("suggested_questions_after_answer", {})
if isinstance(sq_config, dict) and sq_config.get("enabled"):
end_data["suggested_questions"] = await self.agent_service._generate_suggested_questions(
features_config, full_content,
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
"api_base": api_key_obj.api_base}, {}
)
end_data["audio_url"] = await self.agent_service._generate_tts(
features_config, full_content,
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
"api_base": api_key_obj.api_base, "provider": api_key_obj.provider},
tenant_id=tenant_id, workspace_id=workspace_id
)
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info(
@@ -442,7 +493,7 @@ class AppChatService:
try:
message_id = uuid.uuid4()
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), "message_id": str(message_id)}, ensure_ascii=False)}\n\n"
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), 'message_id': str(message_id)}, ensure_ascii=False)}\n\n"
full_content = ""
total_tokens = 0

View File

@@ -109,7 +109,7 @@ class AppService:
return share is not None
def _validate_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> None:
def _validate_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> None:
"""验证应用是否可访问(包括共享应用,用于只读操作)
Args:
@@ -360,6 +360,7 @@ class AppService:
variables=storage_data.get("variables", []),
tools=storage_data.get("tools", []),
skills=storage_data.get("skills", {}),
features=storage_data.get("features", {}),
is_active=True,
created_at=now,
updated_at=now,
@@ -1073,6 +1074,7 @@ class AppService:
# if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.skills = storage_data.get("skills", {})
agent_cfg.features = storage_data.get("features", {})
agent_cfg.updated_at = now
@@ -1173,6 +1175,7 @@ class AppService:
variables=[],
tools=[],
skills=[],
features={},
is_active=True,
created_at=now,
updated_at=now,
@@ -1389,15 +1392,15 @@ class AppService:
return config.config_id
def _update_endusers_memory_config(
def _update_endusers_memory_config_by_workspace(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
memory_config_id: uuid.UUID
) -> int:
"""批量更新应用下所有终端用户的 memory_config_id
Args:
app_id: 应用ID
workspace_id: 工作空间ID
memory_config_id: 新的记忆配置ID
Returns:
@@ -1406,8 +1409,8 @@ class AppService:
from app.repositories.end_user_repository import EndUserRepository
repo = EndUserRepository(self.db)
updated_count = repo.batch_update_memory_config_id(
app_id=app_id,
updated_count = repo.batch_update_memory_config_id_by_workspace(
workspace_id=workspace_id,
memory_config_id=memory_config_id
)
@@ -1578,11 +1581,15 @@ class AppService:
)
if memory_config_id:
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
logger.info(
f"发布时更新终端用户记忆配置: app_id={app_id}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
)
app = self.db.query(App).filter(App.id == app_id).first()
if app:
updated_count = self._update_endusers_memory_config_by_workspace(
app.workspace_id, memory_config_id
)
logger.info(
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
)
# 更新当前发布版本指针
app.current_release_id = release.id

View File

@@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
@@ -262,9 +263,12 @@ class AgentRunService:
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
"""加载工具配置"""
if not tools_config:
return []
tools = []
if web_search:
search_tool = create_web_search_tool({})
tools.append(search_tool)
if not tools_config:
return tools
tool_service = ToolService(self.db)
if tools_config and isinstance(tools_config, list):
@@ -273,24 +277,15 @@ class AgentRunService:
# 根据工具名称查找工具实例
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif tools_config and isinstance(tools_config, dict):
web_search_choice = tools_config.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search and web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
return tools
def load_skill_config(
@@ -373,6 +368,86 @@ class AgentRunService:
)
return tools, bool(memory_config.get("enabled"))
@staticmethod
def _validate_file_upload(
features_config: Dict[str, Any],
files: Optional[List[FileInput]]
) -> None:
"""校验上传文件是否符合 file_upload 配置"""
if not files:
return
fu = features_config.get("file_upload", {})
if not (isinstance(fu, dict) and fu.get("enabled")):
raise BusinessException("该应用未开启文件上传功能", BizCode.BAD_REQUEST)
max_count = fu.get("max_file_count", 5)
if len(files) > max_count:
raise BusinessException(f"文件数量超过限制(最多 {max_count} 个)", BizCode.BAD_REQUEST)
# 校验传输方式
allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"])
for f in files:
if str(f.transfer_method) not in allowed_methods:
raise BusinessException(
f"不支持的文件传输方式:{f.transfer_method},允许的方式:{', '.join(allowed_methods)}",
BizCode.BAD_REQUEST
)
# 各类型对应的开关和大小限制配置键
type_cfg = {
"image": ("image_enabled", "image_max_size_mb", 20, "图片"),
"audio": ("audio_enabled", "audio_max_size_mb", 50, "音频"),
"document": ("document_enabled", "document_max_size_mb", 100, "文档"),
"video": ("video_enabled", "video_max_size_mb", 500, "视频"),
}
for f in files:
ftype = str(f.type) # 如 "image", "audio", "document", "video"
cfg = type_cfg.get(ftype)
if cfg is None:
continue
enabled_key, size_key, default_max_mb, label = cfg
# 校验类型开关
if not fu.get(enabled_key):
raise BusinessException(f"该应用未开启{label}文件上传", BizCode.BAD_REQUEST)
# 校验文件大小(仅当内容已加载时)
content = f.get_content()
if content is not None:
max_mb = fu.get(size_key, default_max_mb)
size_mb = len(content) / (1024 * 1024)
if size_mb > max_mb:
raise BusinessException(
f"{label}文件大小超过限制(最大 {max_mb}MB当前 {size_mb:.1f}MB",
BizCode.BAD_REQUEST
)
@staticmethod
def _inject_opening_statement(
features_config: Dict[str, Any],
system_prompt: str,
is_new_conversation: bool
) -> str:
"""首轮对话时将开场白注入 system_prompt"""
if not is_new_conversation:
return system_prompt
opening = features_config.get("opening_statement", {})
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
return system_prompt
statement = opening["statement"]
return f"{system_prompt}\n\n[对话开场白]\n{statement}"
@staticmethod
def _filter_citations(
features_config: Dict[str, Any],
citations: List[Any]
) -> List[Any]:
"""根据 citation 开关决定是否返回引用来源"""
citation_cfg = features_config.get("citation", {})
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
return citations
return []
async def run(
self,
*,
@@ -415,6 +490,15 @@ class AgentRunService:
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory
features_config: dict = agent_config.features or {}
# 从 features 中读取功能开关(优先级高于参数默认值)
web_search_feature = features_config.get("web_search", {})
if not isinstance(web_search_feature, dict) or not web_search_feature.get("enabled"):
web_search = False
# file_upload 校验
self._validate_file_upload(features_config, files)
try:
# 1. 获取 API Key 配置
@@ -449,6 +533,10 @@ class AgentRunService:
# 3. 处理系统提示词(支持变量替换)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
# opening_statement首轮对话注入开场白
is_new_conversation = not conversation_id
system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation)
# 4. 准备工具列表
tools = []
@@ -491,12 +579,10 @@ class AgentRunService:
)
# 6. 加载历史消息
history = []
if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=agent_config.memory.get("max_history", 10)
)
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=10
)
# 6. 处理多模态文件
processed_files = None
@@ -551,7 +637,7 @@ class AgentRunService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
# 9. 保存会话消息
if not sub_agent and memory_config and memory_config.get("enabled"):
if not sub_agent:
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,
@@ -575,7 +661,15 @@ class AgentRunService:
"completion_tokens": 0,
"total_tokens": 0
}),
"elapsed_time": elapsed_time
"elapsed_time": elapsed_time,
"suggested_questions": await self._generate_suggested_questions(
features_config, result["content"], api_key_config, effective_params
) if not sub_agent else [],
"citations": self._filter_citations(features_config, result.get("citations", [])),
"audio_url": await self._generate_tts(
features_config, result["content"], api_key_config,
tenant_id=tenant_id, workspace_id=workspace_id
) if not sub_agent else None,
}
logger.info(
@@ -630,6 +724,15 @@ class AgentRunService:
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory
features_config: dict = agent_config.features or {}
# 从 features 中读取功能开关
web_search_feature = features_config.get("web_search", {})
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
web_search = False
# file_upload 校验
self._validate_file_upload(features_config, files)
start_time = time.time()
@@ -659,6 +762,10 @@ class AgentRunService:
# 3. 处理系统提示词(支持变量替换)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
# opening_statement首轮对话注入开场白
is_new_conversation = not conversation_id
system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation)
# 4. 准备工具列表
tools = []
@@ -703,12 +810,10 @@ class AgentRunService:
)
# 6. 加载历史消息
history = []
if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=memory_config.get("max_history", 10)
)
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=memory_config.get("max_history", 10)
)
# 6. 处理多模态文件
processed_files = None
@@ -774,7 +879,7 @@ class AgentRunService:
})
# 10. 保存会话消息
if not sub_agent and memory_config and memory_config.get("enabled"):
if not sub_agent:
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,
@@ -786,12 +891,22 @@ class AgentRunService:
}
)
# 11. 发送结束事件
yield self._format_sse_event("end", {
# 11. 发送结束事件(包含 suggested_questions 和 tts
end_data: Dict[str, Any] = {
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
"message_length": len(full_content)
})
}
if not sub_agent:
end_data["suggested_questions"] = await self._generate_suggested_questions(
features_config, full_content, api_key_config, effective_params
)
end_data["audio_url"] = await self._generate_tts(
features_config, full_content, api_key_config,
tenant_id=tenant_id, workspace_id=workspace_id
)
end_data["citations"] = self._filter_citations(features_config, [])
yield self._format_sse_event("end", end_data)
logger.info(
"流式试运行完成",
@@ -1137,6 +1252,165 @@ class AgentRunService:
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
return {}
async def _generate_suggested_questions(
self,
features_config: Dict[str, Any],
assistant_message: str,
api_key_config: Dict[str, Any],
effective_params: Dict[str, Any]
) -> List[str]:
"""根据 suggested_questions_after_answer 配置生成下一步建议问题"""
sq_config = features_config.get("suggested_questions_after_answer", {})
if not isinstance(sq_config, dict) or not sq_config.get("enabled"):
return []
try:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
llm = ChatOpenAI(
model=api_key_config["model_name"],
api_key=api_key_config["api_key"],
base_url=api_key_config.get("api_base"),
temperature=0.5,
max_tokens=200,
)
prompt = (
f"根据以下AI回复生成3个用户可能继续追问的简短问题每行一个不加序号\n\n{assistant_message}"
)
resp = await llm.ainvoke([HumanMessage(content=prompt)])
lines = [l.strip() for l in resp.content.strip().split("\n") if l.strip()]
return lines[:3]
except Exception as e:
logger.warning(f"生成建议问题失败: {e}")
return []
async def _generate_tts(
self,
features_config: Dict[str, Any],
text: str,
api_key_config: Dict[str, Any],
tenant_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
) -> Optional[str]:
"""根据 text_to_speech 配置生成语音,上传到存储并返回 URL"""
tts_config = features_config.get("text_to_speech", {})
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
return None
if not text or not text.strip():
return None
try:
from app.services.file_storage_service import FileStorageService
provider = api_key_config.get("provider", "openai")
api_key = api_key_config.get("api_key")
api_base = api_key_config.get("api_base")
voice = tts_config.get("voice")
if provider == "dashscope":
audio_bytes, file_ext, content_type = await self._tts_dashscope(
api_key=api_key,
text=text,
voice=voice or "longxiaochun", # 会根据 model 版本自动修正后缀
tts_config=tts_config,
)
else:
# OpenAI 兼容接口openai / xinference / gpustack 等)
audio_bytes, file_ext, content_type = await self._tts_openai(
api_key=api_key,
api_base=api_base,
text=text,
voice=voice or "alloy",
)
storage_service = FileStorageService()
file_id = uuid.uuid4()
file_key = await storage_service.upload_file(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
content=audio_bytes,
content_type=content_type,
)
# 保存文件元数据到数据库
from app.models.file_metadata_model import FileMetadata
db_file = FileMetadata(
id=file_id,
tenant_id=tenant_id,
workspace_id=workspace_id,
file_key=file_key,
file_name=f"tts_{file_id}{file_ext}",
file_ext=file_ext,
file_size=len(audio_bytes),
content_type=content_type,
status="completed",
)
self.db.add(db_file)
self.db.commit()
server_url = settings.FILE_LOCAL_SERVER_URL
audio_url = f"{server_url}/storage/permanent/{file_id}"
logger.debug(f"TTS 生成成功provider={provider}, file_key={file_key}")
return audio_url
except Exception as e:
logger.warning(f"TTS 生成失败: {e}")
return None
@staticmethod
async def _tts_openai(
api_key: str,
api_base: Optional[str],
text: str,
voice: str,
) -> tuple:
"""OpenAI 兼容 TTS返回 (audio_bytes, file_ext, content_type)"""
from openai import AsyncOpenAI
client = AsyncOpenAI(api_key=api_key, base_url=api_base)
response = await client.audio.speech.create(
model="tts-1",
voice=voice,
input=text[:4096],
)
return response.content, ".mp3", "audio/mpeg"
@staticmethod
async def _tts_dashscope(
api_key: str,
text: str,
voice: str,
tts_config: Dict[str, Any],
) -> tuple:
"""DashScope CosyVoice TTS返回 (audio_bytes, file_ext, content_type)"""
import dashscope
from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat
model = tts_config.get("model") or "cosyvoice-v2"
is_v2 = model.endswith("-v2")
# cosyvoice-v2 音色名带 _v2 后缀v1 不带
# 如果用户传入的 voice 不匹配当前模型版本,自动修正
if is_v2 and not voice.endswith("_v2"):
voice = voice + "_v2"
elif not is_v2 and voice.endswith("_v2"):
voice = voice[:-3] # 去掉 _v2
def _sync_call() -> bytes:
dashscope.api_key = api_key
synthesizer = SpeechSynthesizer(
model=model,
voice=voice,
format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
)
audio = synthesizer.call(text[:4096])
if audio is None:
raise RuntimeError("DashScope TTS 返回空音频")
return audio
audio_bytes = await asyncio.to_thread(_sync_call)
return audio_bytes, ".mp3", "audio/mpeg"
def _replace_variables(
self,
text: str,
@@ -1221,6 +1495,12 @@ class AgentRunService:
}
)
# 提前校验文件上传(与 run() 内部保持一致)
features_config: dict = agent_config.features or {}
if hasattr(features_config, 'model_dump'):
features_config = features_config.model_dump()
# self._validate_file_upload(features_config, files)
async def run_single_model(model_info):
"""运行单个模型"""
try:
@@ -1271,6 +1551,9 @@ class AgentRunService:
if elapsed > 0 and usage.get("completion_tokens") else None
),
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
"audio_url": result.get("audio_url"),
"citations": result.get("citations", []),
"suggested_questions": result.get("suggested_questions", []),
"error": None
}
@@ -1343,7 +1626,12 @@ class AgentRunService:
)
return {
"results": results,
"results": [{
**r,
"audio_url": r.get("audio_url"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
} for r in results],
"total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results),
"successful_count": len(successful),
"failed_count": len(failed),
@@ -1434,6 +1722,12 @@ class AgentRunService:
extra={"model_count": len(models), "parallel": parallel}
)
# 提前校验文件上传
# features_config: dict = agent_config.features or {}
# if hasattr(features_config, 'model_dump'):
# features_config = features_config.model_dump()
# self._validate_file_upload(features_config, files)
# 发送开始事件
yield self._format_sse_event("compare_start", {
"conversation_id": conversation_id,
@@ -1465,6 +1759,9 @@ class AgentRunService:
start_time = time.time()
full_content = ""
returned_conversation_id = model_conversation_id
audio_url = None
citations = []
suggested_questions = []
# 临时修改参数
original_params = agent_config.model_parameters
@@ -1518,6 +1815,12 @@ class AgentRunService:
"content": chunk
}))
# 从 end 事件中提取 features 输出字段
if event_type == "end" and event_data:
audio_url = event_data.get("audio_url")
citations = event_data.get("citations", [])
suggested_questions = event_data.get("suggested_questions", [])
if event_type == "error" and event_data:
await event_queue.put(self._format_sse_event("model_error", {
"model_index": idx,
@@ -1543,6 +1846,9 @@ class AgentRunService:
"parameters_used": model_info["parameters"],
"message": full_content,
"elapsed_time": elapsed,
"audio_url": audio_url,
"citations": citations,
"suggested_questions": suggested_questions,
"error": None
}
@@ -1554,6 +1860,9 @@ class AgentRunService:
"conversation_id": returned_conversation_id,
"elapsed_time": elapsed,
"message_length": len(full_content),
"audio_url": audio_url,
"citations": citations,
"suggested_questions": suggested_questions,
"timestamp": time.time()
}))
@@ -1685,8 +1994,11 @@ class AgentRunService:
"model_name": r["model_name"],
"label": r["label"],
"conversation_id": r.get("conversation_id"),
"message": r.get("message"), # 包含完整消息
"message": r.get("message"),
"elapsed_time": r.get("elapsed_time", 0),
"audio_url": r.get("audio_url"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
"error": r.get("error")
})

View File

@@ -68,14 +68,14 @@ def get_workspace_end_users(
return []
# 提取所有 app_id
app_ids = [app.id for app in apps_orm]
# app_ids = [app.id for app in apps_orm]
# 批量查询所有 end_users一次查询而非循环查询
# 按 created_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性
from app.models.end_user_model import EndUser as EndUserModel
from sqlalchemy import desc, nullslast
end_users_orm = db.query(EndUserModel).filter(
EndUserModel.app_id.in_(app_ids)
EndUserModel.workspace_id == workspace_id
).order_by(
nullslast(desc(EndUserModel.created_at)),
desc(EndUserModel.id)

View File

@@ -78,7 +78,7 @@ class ToolService:
def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]:
"""获取工具详情"""
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
config = self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id)
return self._config_to_info(config) if config else None
def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None):
@@ -237,7 +237,7 @@ class ToolService:
return False
def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool:
"""删除工具"""
"""删除工具(逻辑删除)"""
config = self._get_tool_config(tool_id, tenant_id)
if not config:
return False
@@ -246,14 +246,7 @@ class ToolService:
raise ValueError("内置工具不允许删除")
try:
# 删除关联表记录
if config.tool_type == ToolType.CUSTOM.value:
self.db.query(CustomToolConfig).filter(CustomToolConfig.id == config.id).delete()
elif config.tool_type == ToolType.MCP.value:
self.db.query(MCPToolConfig).filter(MCPToolConfig.id == config.id).delete()
# 删除主表记录ToolExecution会通过cascade自动删除
self.db.delete(config)
config.is_active = False
self._clear_tool_cache(tool_id)
self.db.commit()
return True
@@ -262,6 +255,27 @@ class ToolService:
logger.error(f"删除工具失败: {tool_id}, {e}")
return False
def set_tool_active(self, tool_id: str, tenant_id: uuid.UUID, is_active: bool) -> bool:
"""设置工具可用状态(启用/禁用)"""
# 直接查询,包含 is_active=False 的记录
config = self.db.query(ToolConfig).filter(
ToolConfig.id == uuid.UUID(tool_id),
ToolConfig.tenant_id == tenant_id
).first()
if not config:
return False
if config.tool_type == ToolType.BUILTIN.value:
raise ValueError("内置工具不允许修改可用状态")
try:
config.is_active = is_active
self._clear_tool_cache(tool_id)
self.db.commit()
return True
except Exception as e:
self.db.rollback()
logger.error(f"设置工具状态失败: {tool_id}, {e}")
return False
async def execute_tool(
self,
tool_id: str,
@@ -378,7 +392,7 @@ class ToolService:
Returns:
方法列表或None
"""
config = self._get_tool_config(tool_id, tenant_id)
config = self._get_tool_config_all(tool_id, tenant_id)
if not config:
return None
@@ -857,16 +871,20 @@ class ToolService:
}
def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
"""获取工具配置"""
"""获取工具配置(仅返回 is_active=True)"""
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
def _get_tool_config_all(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
"""获取工具配置(返回所有)"""
return self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id)
def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
"""获取工具实例"""
"""获取工具实例(仅返回 is_active=True 的工具)"""
if tool_id in self._tool_cache:
return self._tool_cache[tool_id]
config = self._get_tool_config(tool_id, tenant_id)
if not config:
if not config or not config.is_active:
return None
try:
@@ -980,6 +998,7 @@ class ToolService:
tags=config.tags or [],
tenant_id=str(config.tenant_id) if config.tenant_id else None,
config_data=config_data,
is_active=config.is_active,
created_at=config.created_at
)