feat(agent): add input variable validation
This commit is contained in:
@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -26,6 +27,7 @@ from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
@@ -52,8 +54,12 @@ class LongTermMemoryInput(BaseModel):
|
||||
description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
|
||||
|
||||
|
||||
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None):
|
||||
def create_long_term_memory_tool(
|
||||
memory_config: Dict[str, Any],
|
||||
end_user_id: str,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
):
|
||||
"""创建记忆工具,
|
||||
|
||||
|
||||
@@ -61,6 +67,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
memory_config: 记忆配置
|
||||
end_user_id: 用户ID
|
||||
storage_type: 存储类型(可选)
|
||||
user_rag_memory_id: 用户RAG记忆ID(可选)
|
||||
|
||||
Returns:
|
||||
长期记忆工具
|
||||
@@ -188,7 +195,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
||||
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
||||
|
||||
Args:
|
||||
query: 需要检索的问题或关键词
|
||||
kb_config: 知识库配置
|
||||
kb_ids: 知识库ID列表
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
检索到的相关知识内容
|
||||
@@ -232,17 +241,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
||||
return knowledge_retrieval_tool
|
||||
|
||||
|
||||
class DraftRunService:
|
||||
"""试运行服务类"""
|
||||
class AgentRunService:
|
||||
"""Agent运行服务类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化试运行服务
|
||||
"""Agent运行服务
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
@staticmethod
|
||||
def prepare_variables(
|
||||
input_vars: dict | None,
|
||||
variables_config: dict | None
|
||||
) -> dict:
|
||||
input_vars = input_vars or {}
|
||||
for variable in variables_config:
|
||||
if variable.get("required") and variable.get("name") not in input_vars:
|
||||
raise ValueError(f"The required parameter '{variable.get('name')}' was not provided")
|
||||
return input_vars
|
||||
|
||||
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
|
||||
"""加载工具配置"""
|
||||
if not tools_config:
|
||||
return []
|
||||
tools = []
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
if tools_config and isinstance(tools_config, list):
|
||||
for tool_config in tools_config:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif tools_config and isinstance(tools_config, dict):
|
||||
web_search_choice = tools_config.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search and web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
return tools
|
||||
|
||||
def load_skill_config(
|
||||
self,
|
||||
skills_config: dict | None,
|
||||
message: str, tenant_id
|
||||
) -> tuple[list, str]:
|
||||
if not skills_config:
|
||||
return [], ""
|
||||
|
||||
tools = []
|
||||
skill_prompts = ""
|
||||
skill_enable = skills_config.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills_config)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
skill_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
|
||||
return tools, skill_prompts
|
||||
|
||||
def load_knowledge_retrieval_config(
|
||||
self,
|
||||
knowledge_retrieval_config: dict | None,
|
||||
user_id
|
||||
) -> list:
|
||||
if not knowledge_retrieval_config:
|
||||
return []
|
||||
|
||||
tools = []
|
||||
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
|
||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
||||
if kb_ids:
|
||||
# 创建知识库检索工具
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加知识库检索工具",
|
||||
extra={
|
||||
"kb_ids": kb_ids,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
return tools
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
memory_config: dict | None,
|
||||
user_id,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
) -> tuple[list, bool]:
|
||||
"""加载长期记忆配置"""
|
||||
if not memory_config:
|
||||
return [], False
|
||||
|
||||
tools = []
|
||||
if memory_config.get("enabled"):
|
||||
if user_id:
|
||||
# 创建长期记忆工具
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加长期记忆工具",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
return tools, bool(memory_config.get("enabled"))
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
@@ -270,19 +403,21 @@ class DraftRunService:
|
||||
conversation_id: 会话ID(用于多轮对话)
|
||||
user_id: 用户ID
|
||||
variables: 自定义变量参数值
|
||||
storage_type: 存储类型(可选)
|
||||
user_rag_memory_id: 用户RAG记忆ID(可选)
|
||||
web_search: 是否启用网络搜索(默认True)
|
||||
memory: 是否启用长期记忆(默认True)
|
||||
sub_agent: 是否为子代理调用(默认False)
|
||||
files: 多模态文件列表(可选)
|
||||
|
||||
Returns:
|
||||
Dict: 包含 AI 回复和元数据的字典
|
||||
"""
|
||||
memory_flag = False
|
||||
|
||||
print('===========', storage_type)
|
||||
|
||||
print(user_id)
|
||||
if variables == None: variables = {}
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
|
||||
start_time = time.time()
|
||||
tools_config: dict | list | None = agent_config.tools
|
||||
skills_config: dict | None = agent_config.skills
|
||||
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||||
memory_config: dict | None = agent_config.memory
|
||||
|
||||
try:
|
||||
# 1. 获取 API Key 配置
|
||||
@@ -302,112 +437,40 @@ class DraftRunService:
|
||||
agent_config=agent_config
|
||||
)
|
||||
|
||||
items_params = variables
|
||||
if sub_agent:
|
||||
variables = self.prepare_variables(variables, agent_config.variables)
|
||||
else:
|
||||
# FIXME: subagent input valid
|
||||
variables = variables or {}
|
||||
|
||||
system_prompt = render_prompt_message(
|
||||
agent_config.system_prompt, # 修正拼写错误
|
||||
agent_config.system_prompt,
|
||||
PromptMessageRole.USER,
|
||||
items_params
|
||||
variables
|
||||
)
|
||||
|
||||
# 3. 处理系统提示词(支持变量替换)
|
||||
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
||||
print('系统提示词:', system_prompt)
|
||||
|
||||
# 4. 准备工具列表
|
||||
tools = []
|
||||
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||
for tool_config in agent_config.tools:
|
||||
print("+" * 50)
|
||||
print(f"agent_config:{agent_config}")
|
||||
print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
||||
web_tools = agent_config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||
skills = agent_config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
kb_config = agent_config.knowledge_retrieval
|
||||
knowledge_bases = kb_config.get("knowledge_bases", [])
|
||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
||||
if kb_ids:
|
||||
# 创建知识库检索工具
|
||||
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加知识库检索工具",
|
||||
extra={
|
||||
"kb_ids": kb_ids,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
|
||||
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
if memory:
|
||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||
memory_flag = True
|
||||
|
||||
memory_config = agent_config.memory
|
||||
if user_id:
|
||||
# 创建长期记忆工具
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加长期记忆工具",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
memory_tools, memory_flag = self.load_memory_config(
|
||||
memory_config, user_id, storage_type, user_rag_memory_id
|
||||
)
|
||||
tools.extend(memory_tools)
|
||||
|
||||
# 4. 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
@@ -432,7 +495,7 @@ class DraftRunService:
|
||||
|
||||
# 6. 加载历史消息
|
||||
history = []
|
||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||
if memory_config and memory_config.get("enabled"):
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=agent_config.memory.get("max_history", 10)
|
||||
@@ -482,7 +545,7 @@ class DraftRunService:
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||
|
||||
# 9. 保存会话消息
|
||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||
if not sub_agent and memory_config and memory_config.get("enabled"):
|
||||
await self._save_conversation_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=message,
|
||||
@@ -557,16 +620,21 @@ class DraftRunService:
|
||||
Yields:
|
||||
str: SSE 格式的事件数据
|
||||
"""
|
||||
memory_flag = False
|
||||
if variables == None: variables = {}
|
||||
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
tools_config: dict | list | None = agent_config.tools
|
||||
skills_config: dict | None = agent_config.skills
|
||||
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||||
memory_config: dict | None = agent_config.memory
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 1. 获取 API Key 配置
|
||||
api_key_config = await self._get_api_key(model_config.id)
|
||||
if not sub_agent:
|
||||
variables = self.prepare_variables(variables, agent_config.variables)
|
||||
else:
|
||||
# FIXME: subagent input valid
|
||||
variables = variables or {}
|
||||
|
||||
# 2. 合并模型参数
|
||||
effective_params = ModelParameterMerger.get_effective_parameters(
|
||||
@@ -588,95 +656,22 @@ class DraftRunService:
|
||||
# 4. 准备工具列表
|
||||
tools = []
|
||||
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
for tool_config in agent_config.tools:
|
||||
# print("+"*50)
|
||||
# print(f"agent_config:{agent_config}")
|
||||
# print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
||||
web_tools = agent_config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
|
||||
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||
skills = agent_config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
kb_config = agent_config.knowledge_retrieval
|
||||
knowledge_bases = kb_config.get("knowledge_bases", [])
|
||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
||||
if kb_ids:
|
||||
# 创建知识库检索工具
|
||||
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加知识库检索工具",
|
||||
extra={
|
||||
"kb_ids": kb_ids,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
if memory:
|
||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||
memory_flag = True
|
||||
memory_config = agent_config.memory
|
||||
if user_id:
|
||||
# 创建长期记忆工具
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加长期记忆工具",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
tools.extend(memory_tools)
|
||||
|
||||
# 4. 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
@@ -702,10 +697,10 @@ class DraftRunService:
|
||||
|
||||
# 6. 加载历史消息
|
||||
history = []
|
||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||
if memory_config and memory_config.get("enabled"):
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=agent_config.memory.get("max_history", 10)
|
||||
max_history=memory_config.get("max_history", 10)
|
||||
)
|
||||
|
||||
# 6. 处理多模态文件
|
||||
@@ -763,7 +758,7 @@ class DraftRunService:
|
||||
})
|
||||
|
||||
# 10. 保存会话消息
|
||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||
if not sub_agent and memory_config and memory_config.get("enabled"):
|
||||
await self._save_conversation_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=message,
|
||||
@@ -969,7 +964,6 @@ class DraftRunService:
|
||||
List[Dict]: 历史消息列表
|
||||
"""
|
||||
try:
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
conversation_service = ConversationService(self.db)
|
||||
history = conversation_service.get_conversation_history(
|
||||
@@ -1489,6 +1483,15 @@ class DraftRunService:
|
||||
"conversation_id": returned_conversation_id,
|
||||
"content": chunk
|
||||
}))
|
||||
|
||||
if event_type == "error" and event_data:
|
||||
await event_queue.put(self._format_sse_event("model_error", {
|
||||
"model_index": idx,
|
||||
"model_config_id": model_config_id,
|
||||
"label": model_label,
|
||||
"conversation_id": returned_conversation_id,
|
||||
"error": event_data.get("error", "未知错误")
|
||||
}))
|
||||
except Exception as e:
|
||||
logger.warning(f"解析流式事件失败: {e}")
|
||||
finally:
|
||||
@@ -1673,41 +1676,3 @@ class DraftRunService:
|
||||
"total_time": sum(r.get("elapsed_time", 0) for r in results)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def draft_run(
|
||||
db: Session,
|
||||
*,
|
||||
agent_config: AgentConfig,
|
||||
model_config: ModelConfig,
|
||||
message: str,
|
||||
user_id: Optional[str] = None,
|
||||
kb_ids: Optional[List[str]] = None,
|
||||
similarity_threshold: float = 0.7,
|
||||
top_k: int = 3
|
||||
) -> Dict[str, Any]:
|
||||
"""试运行 Agent(便捷函数)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
agent_config: Agent 配置
|
||||
model_config: 模型配置
|
||||
message: 用户消息
|
||||
user_id: 用户ID
|
||||
kb_ids: 知识库ID列表
|
||||
similarity_threshold: 相似度阈值
|
||||
top_k: 检索返回的文档数量
|
||||
|
||||
Returns:
|
||||
Dict: 包含 AI 回复和元数据的字典
|
||||
"""
|
||||
service = DraftRunService(db)
|
||||
return await service.run(
|
||||
agent_config=agent_config,
|
||||
model_config=model_config,
|
||||
message=message,
|
||||
user_id=user_id,
|
||||
kb_ids=kb_ids,
|
||||
similarity_threshold=similarity_threshold,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user