feat(agent): add input variable validation
This commit is contained in:
@@ -396,10 +396,10 @@ async def draft_run(
|
|||||||
from app.models import AgentConfig, ModelConfig
|
from app.models import AgentConfig, ModelConfig
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.services.draft_run_service import DraftRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
|
|
||||||
service = AppService(db)
|
service = AppService(db)
|
||||||
draft_service = DraftRunService(db)
|
draft_service = AgentRunService(db)
|
||||||
|
|
||||||
# 1. 验证应用
|
# 1. 验证应用
|
||||||
app = service._get_app_or_404(app_id)
|
app = service._get_app_or_404(app_id)
|
||||||
@@ -484,8 +484,8 @@ async def draft_run(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.services.draft_run_service import DraftRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
draft_service = DraftRunService(db)
|
draft_service = AgentRunService(db)
|
||||||
result = await draft_service.run(
|
result = await draft_service.run(
|
||||||
agent_config=agent_cfg,
|
agent_config=agent_cfg,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@@ -789,8 +789,8 @@ async def draft_run_compare(
|
|||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
from app.services.draft_run_service import DraftRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
draft_service = DraftRunService(db)
|
draft_service = AgentRunService(db)
|
||||||
async for event in draft_service.run_compare_stream(
|
async for event in draft_service.run_compare_stream(
|
||||||
agent_config=agent_cfg,
|
agent_config=agent_cfg,
|
||||||
models=model_configs,
|
models=model_configs,
|
||||||
@@ -820,8 +820,8 @@ async def draft_run_compare(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 非流式返回
|
# 非流式返回
|
||||||
from app.services.draft_run_service import DraftRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
draft_service = DraftRunService(db)
|
draft_service = AgentRunService(db)
|
||||||
result = await draft_service.run_compare(
|
result = await draft_service.run_compare(
|
||||||
agent_config=agent_cfg,
|
agent_config=agent_cfg,
|
||||||
models=model_configs,
|
models=model_configs,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class RedBearModelConfig(BaseModel):
|
class RedBearModelConfig(BaseModel):
|
||||||
"""模型配置基类"""
|
"""模型配置基类"""
|
||||||
model_name: str
|
model_name: str
|
||||||
@@ -32,9 +33,10 @@ class RedBearModelConfig(BaseModel):
|
|||||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||||
max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
|
max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
|
||||||
concurrency: int = 5 # 并发限流
|
concurrency: int = 5 # 并发限流
|
||||||
extra_params: Dict[str, Any] = {}
|
extra_params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
class RedBearModelFactory:
|
class RedBearModelFactory:
|
||||||
"""模型工厂类"""
|
"""模型工厂类"""
|
||||||
|
|
||||||
@@ -87,7 +89,7 @@ class RedBearModelFactory:
|
|||||||
"timeout": timeout_config,
|
"timeout": timeout_config,
|
||||||
"max_retries": config.max_retries,
|
"max_retries": config.max_retries,
|
||||||
**config.extra_params
|
**config.extra_params
|
||||||
}
|
}
|
||||||
elif provider == ModelProvider.DASHSCOPE:
|
elif provider == ModelProvider.DASHSCOPE:
|
||||||
# DashScope (通义千问) 使用自己的参数格式
|
# DashScope (通义千问) 使用自己的参数格式
|
||||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||||
@@ -145,16 +147,17 @@ class RedBearModelFactory:
|
|||||||
"""根据提供商获取模型参数"""
|
"""根据提供商获取模型参数"""
|
||||||
provider = config.provider.lower()
|
provider = config.provider.lower()
|
||||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||||
return {
|
return {
|
||||||
"model": config.model_name,
|
"model": config.model_name,
|
||||||
# "base_url": config.base_url,
|
# "base_url": config.base_url,
|
||||||
"jina_api_key": config.api_key,
|
"jina_api_key": config.api_key,
|
||||||
**config.extra_params
|
**config.extra_params
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||||
|
|
||||||
def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:
|
|
||||||
|
def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]:
|
||||||
"""根据模型提供商获取对应的模型类"""
|
"""根据模型提供商获取对应的模型类"""
|
||||||
provider = config.provider.lower()
|
provider = config.provider.lower()
|
||||||
|
|
||||||
@@ -183,10 +186,11 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.
|
|||||||
else:
|
else:
|
||||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||||
|
|
||||||
|
|
||||||
def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
||||||
"""根据模型提供商获取对应的模型类"""
|
"""根据模型提供商获取对应的模型类"""
|
||||||
provider = provider.lower()
|
provider = provider.lower()
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
return OpenAIEmbeddings
|
return OpenAIEmbeddings
|
||||||
elif provider == ModelProvider.DASHSCOPE:
|
elif provider == ModelProvider.DASHSCOPE:
|
||||||
@@ -201,13 +205,14 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
|||||||
else:
|
else:
|
||||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||||
|
|
||||||
|
|
||||||
def get_provider_rerank_class(provider: str):
|
def get_provider_rerank_class(provider: str):
|
||||||
"""根据模型提供商获取对应的模型类"""
|
"""根据模型提供商获取对应的模型类"""
|
||||||
provider = provider.lower()
|
provider = provider.lower()
|
||||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||||
from langchain_community.document_compressors import JinaRerank
|
from langchain_community.document_compressors import JinaRerank
|
||||||
return JinaRerank
|
return JinaRerank
|
||||||
# elif provider == ModelProvider.OLLAMA:
|
# elif provider == ModelProvider.OLLAMA:
|
||||||
# from langchain_ollama import OllamaEmbeddings
|
# from langchain_ollama import OllamaEmbeddings
|
||||||
# return OllamaEmbeddings
|
# return OllamaEmbeddings
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from app.core.workflow.nodes.base_node import BaseNode
|
|||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.models import AppRelease
|
from app.models import AppRelease
|
||||||
from app.services.draft_run_service import DraftRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
|
|||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {"output": VariableType.STRING}
|
return {"output": VariableType.STRING}
|
||||||
|
|
||||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
|
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]:
|
||||||
"""准备 Agent(公共逻辑)
|
"""准备 Agent(公共逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -65,7 +65,7 @@ class AgentNode(BaseNode):
|
|||||||
if not release:
|
if not release:
|
||||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||||
|
|
||||||
draft_service = DraftRunService(db)
|
draft_service = AgentRunService(db)
|
||||||
|
|
||||||
return draft_service, release, message
|
return draft_service, release, message
|
||||||
|
|
||||||
|
|||||||
@@ -155,8 +155,7 @@ class ApiKey(BaseModel):
|
|||||||
return datetime.datetime.now() > self.expires_at
|
return datetime.datetime.now() > self.expires_at
|
||||||
|
|
||||||
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
|
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
|
||||||
@classmethod
|
def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||||
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
|
||||||
"""将datetime转换为时间戳"""
|
"""将datetime转换为时间戳"""
|
||||||
return datetime_to_timestamp(v)
|
return datetime_to_timestamp(v)
|
||||||
|
|
||||||
@@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel):
|
|||||||
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
|
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
|
||||||
|
|
||||||
@field_serializer('last_used_at')
|
@field_serializer('last_used_at')
|
||||||
@classmethod
|
def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||||
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
|
||||||
"""将datetime转换为时间戳"""
|
"""将datetime转换为时间戳"""
|
||||||
return datetime_to_timestamp(v)
|
return datetime_to_timestamp(v)
|
||||||
|
|
||||||
@@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel):
|
|||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
|
|
||||||
@field_serializer('created_at')
|
@field_serializer('created_at')
|
||||||
@classmethod
|
def serialize_datetime(self, v: datetime.datetime) -> int:
|
||||||
def serialize_datetime(cls, v: datetime.datetime) -> int:
|
|
||||||
"""将datetime转换为时间戳"""
|
"""将datetime转换为时间戳"""
|
||||||
return datetime_to_timestamp(v)
|
return datetime_to_timestamp(v)
|
||||||
|
|||||||
@@ -64,14 +64,14 @@ class ExecutionConfig(BaseModel):
|
|||||||
class MultiAgentConfigCreate(BaseModel):
|
class MultiAgentConfigCreate(BaseModel):
|
||||||
"""创建多 Agent 配置"""
|
"""创建多 Agent 配置"""
|
||||||
master_agent_id: uuid.UUID = Field(..., description="主 Agent ID")
|
master_agent_id: uuid.UUID = Field(..., description="主 Agent ID")
|
||||||
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
|
master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称")
|
||||||
orchestration_mode: str = Field(
|
orchestration_mode: str = Field(
|
||||||
default="collaboration",
|
default="collaboration",
|
||||||
pattern="^(collaboration|supervisor)$",
|
pattern="^(collaboration|supervisor)$",
|
||||||
description="协作模式:collaboration(协作)| supervisor(监督)"
|
description="协作模式:collaboration(协作)| supervisor(监督)"
|
||||||
)
|
)
|
||||||
sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表")
|
sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表")
|
||||||
routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则")
|
routing_rules: Optional[List[RoutingRule]] = Field(default=None, description="路由规则")
|
||||||
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
|
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
|
||||||
aggregation_strategy: str = Field(
|
aggregation_strategy: str = Field(
|
||||||
default="merge",
|
default="merge",
|
||||||
@@ -83,7 +83,7 @@ class MultiAgentConfigCreate(BaseModel):
|
|||||||
class MultiAgentConfigUpdate(BaseModel):
|
class MultiAgentConfigUpdate(BaseModel):
|
||||||
"""更新多 Agent 配置"""
|
"""更新多 Agent 配置"""
|
||||||
master_agent_id: Optional[uuid.UUID] = None
|
master_agent_id: Optional[uuid.UUID] = None
|
||||||
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
|
master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称")
|
||||||
default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID")
|
default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID")
|
||||||
model_parameters: Optional[ModelParameters] = Field(
|
model_parameters: Optional[ModelParameters] = Field(
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -263,8 +263,8 @@ def create_agent_invocation_tool(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 9. 调用 Agent
|
# 9. 调用 Agent
|
||||||
from app.services.draft_run_service import DraftRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
draft_service = DraftRunService(db)
|
draft_service = AgentRunService(db)
|
||||||
|
|
||||||
result = await draft_service.run(
|
result = await draft_service.run(
|
||||||
agent_config=agent_config,
|
agent_config=agent_config,
|
||||||
|
|||||||
@@ -10,25 +10,24 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.core.agent.agent_middleware import AgentMiddleware
|
from app.core.agent.agent_middleware import AgentMiddleware
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.db import get_db, get_db_context
|
|
||||||
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
|
||||||
from app.schemas import DraftRunRequest
|
|
||||||
from app.schemas.app_schema import FileInput
|
|
||||||
from app.services.tool_service import ToolService
|
|
||||||
from app.repositories.tool_repository import ToolRepository
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.models import MultiAgentConfig, AgentConfig
|
from app.models import MultiAgentConfig, AgentConfig
|
||||||
|
from app.models import WorkflowConfig
|
||||||
|
from app.repositories.tool_repository import ToolRepository
|
||||||
|
from app.schemas import DraftRunRequest
|
||||||
|
from app.schemas.app_schema import FileInput
|
||||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
|
||||||
|
AgentRunService
|
||||||
from app.services.draft_run_service import create_web_search_tool
|
from app.services.draft_run_service import create_web_search_tool
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||||
from app.services.workflow_service import WorkflowService
|
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
from app.services.tool_service import ToolService
|
||||||
|
from app.services.workflow_service import WorkflowService
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -39,6 +38,8 @@ class AppChatService:
|
|||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
self.db = db
|
self.db = db
|
||||||
self.conversation_service = ConversationService(db)
|
self.conversation_service = ConversationService(db)
|
||||||
|
self.agent_service = AgentRunService(db)
|
||||||
|
self.workflow_service = WorkflowService(db)
|
||||||
|
|
||||||
async def agnet_chat(
|
async def agnet_chat(
|
||||||
self,
|
self,
|
||||||
@@ -55,12 +56,10 @@ class AppChatService:
|
|||||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = None
|
config_id = None
|
||||||
|
|
||||||
if variables is None:
|
variables = self.agent_service.prepare_variables(variables, config.variables)
|
||||||
variables = {}
|
|
||||||
|
|
||||||
# 获取模型配置ID
|
# 获取模型配置ID
|
||||||
model_config_id = config.default_model_config_id
|
model_config_id = config.default_model_config_id
|
||||||
@@ -79,74 +78,20 @@ class AppChatService:
|
|||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
# 获取工具服务
|
# 获取工具服务
|
||||||
tool_service = ToolService(self.db)
|
|
||||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||||
|
|
||||||
# 从配置中获取启用的工具
|
tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
|
||||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
|
||||||
for tool_config in config.tools:
|
tools.extend(skill_tools)
|
||||||
if tool_config.get("enabled", False):
|
if skill_prompts:
|
||||||
# 根据工具名称查找工具实例
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
|
||||||
if tool_instance:
|
|
||||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
|
||||||
continue
|
|
||||||
# 转换为LangChain工具
|
|
||||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
|
||||||
tools.append(langchain_tool)
|
|
||||||
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
|
||||||
web_tools = config.tools
|
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
if web_search:
|
|
||||||
if web_search_enable:
|
|
||||||
search_tool = create_web_search_tool({})
|
|
||||||
tools.append(search_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加网络搜索工具",
|
|
||||||
extra={
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 加载技能关联的工具
|
|
||||||
if hasattr(config, 'skills') and config.skills:
|
|
||||||
skills = config.skills
|
|
||||||
skill_enable = skills.get("enabled", False)
|
|
||||||
if skill_enable:
|
|
||||||
middleware = AgentMiddleware(skills=skills)
|
|
||||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
|
||||||
tools.extend(skill_tools)
|
|
||||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
|
||||||
|
|
||||||
# 应用动态过滤
|
|
||||||
if skill_configs:
|
|
||||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
|
||||||
tool_to_skill_map)
|
|
||||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
|
||||||
active_prompts = AgentMiddleware.get_active_prompts(
|
|
||||||
activated_skill_ids, skill_configs
|
|
||||||
)
|
|
||||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
|
||||||
|
|
||||||
# 添加知识库检索工具
|
|
||||||
knowledge_retrieval = config.knowledge_retrieval
|
|
||||||
if knowledge_retrieval:
|
|
||||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
|
||||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
|
||||||
if kb_ids:
|
|
||||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
|
||||||
tools.append(kb_tool)
|
|
||||||
|
|
||||||
# 添加长期记忆工具
|
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
if memory == True:
|
if memory:
|
||||||
memory_config = config.memory
|
memory_tools, memory_flag = self.agent_service.load_memory_config(
|
||||||
if memory_config.get("enabled") and user_id:
|
config.memory, user_id, storage_type, user_rag_memory_id
|
||||||
memory_flag = True
|
)
|
||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
tools.extend(memory_tools)
|
||||||
tools.append(memory_tool)
|
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
@@ -246,10 +191,9 @@ class AppChatService:
|
|||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = None
|
config_id = None
|
||||||
|
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
if variables is None:
|
variables = self.agent_service.prepare_variables(variables, config.variables)
|
||||||
variables = {}
|
|
||||||
|
|
||||||
# 获取模型配置ID
|
# 获取模型配置ID
|
||||||
model_config_id = config.default_model_config_id
|
model_config_id = config.default_model_config_id
|
||||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||||
@@ -267,73 +211,22 @@ class AppChatService:
|
|||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
# 获取工具服务
|
# 获取工具服务
|
||||||
tool_service = ToolService(self.db)
|
|
||||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||||
|
|
||||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
|
||||||
for tool_config in config.tools:
|
|
||||||
if tool_config.get("enabled", False):
|
|
||||||
# 根据工具名称查找工具实例
|
|
||||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
|
||||||
if tool_instance:
|
|
||||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
|
||||||
continue
|
|
||||||
# 转换为LangChain工具
|
|
||||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
|
||||||
tools.append(langchain_tool)
|
|
||||||
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
|
||||||
web_tools = config.tools
|
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
if web_search:
|
|
||||||
if web_search_enable:
|
|
||||||
search_tool = create_web_search_tool({})
|
|
||||||
tools.append(search_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加网络搜索工具",
|
|
||||||
extra={
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 加载技能关联的工具
|
|
||||||
if hasattr(config, 'skills') and config.skills:
|
|
||||||
skills = config.skills
|
|
||||||
skill_enable = skills.get("enabled", False)
|
|
||||||
if skill_enable:
|
|
||||||
middleware = AgentMiddleware(skills=skills)
|
|
||||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
|
||||||
tools.extend(skill_tools)
|
|
||||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
|
||||||
|
|
||||||
# 应用动态过滤
|
|
||||||
if skill_configs:
|
|
||||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
|
||||||
tool_to_skill_map)
|
|
||||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
|
||||||
active_prompts = AgentMiddleware.get_active_prompts(
|
|
||||||
activated_skill_ids, skill_configs
|
|
||||||
)
|
|
||||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
|
||||||
|
|
||||||
# 添加知识库检索工具
|
|
||||||
knowledge_retrieval = config.knowledge_retrieval
|
|
||||||
if knowledge_retrieval:
|
|
||||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
|
||||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
|
||||||
if kb_ids:
|
|
||||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
|
||||||
tools.append(kb_tool)
|
|
||||||
|
|
||||||
|
skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
|
||||||
|
tools.extend(skill_tools)
|
||||||
|
if skill_prompts:
|
||||||
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
|
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
|
||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
if memory:
|
if memory:
|
||||||
memory_config = config.memory
|
memory_tools, memory_flag = self.agent_service.load_memory_config(
|
||||||
if memory_config.get("enabled") and user_id:
|
config.memory, user_id, storage_type, user_rag_memory_id
|
||||||
memory_flag = True
|
)
|
||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
tools.extend(memory_tools)
|
||||||
tools.append(memory_tool)
|
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
@@ -372,9 +265,6 @@ class AppChatService:
|
|||||||
processed_files = await multimodal_service.process_files(files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 发送开始事件
|
|
||||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
# 流式调用 Agent(支持多模态)
|
# 流式调用 Agent(支持多模态)
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
@@ -418,7 +308,7 @@ class AppChatService:
|
|||||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||||
|
|
||||||
# 发送结束事件
|
# 发送结束事件
|
||||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
|
||||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -437,7 +327,7 @@ class AppChatService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
|
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
|
||||||
# 发送错误事件
|
# 发送错误事件
|
||||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
yield f"event: end\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
async def multi_agent_chat(
|
async def multi_agent_chat(
|
||||||
self,
|
self,
|
||||||
@@ -491,10 +381,10 @@ class AppChatService:
|
|||||||
"mode": result.get("mode"),
|
"mode": result.get("mode"),
|
||||||
"elapsed_time": result.get("elapsed_time"),
|
"elapsed_time": result.get("elapsed_time"),
|
||||||
"usage": result.get("usage", {
|
"usage": result.get("usage", {
|
||||||
"prompt_tokens": 0,
|
"prompt_tokens": 0,
|
||||||
"completion_tokens": 0,
|
"completion_tokens": 0,
|
||||||
"total_tokens": 0
|
"total_tokens": 0
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -524,8 +414,6 @@ class AppChatService:
|
|||||||
"""多 Agent 聊天(流式)"""
|
"""多 Agent 聊天(流式)"""
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
actual_config_id = None
|
|
||||||
config_id = actual_config_id
|
|
||||||
|
|
||||||
if variables is None:
|
if variables is None:
|
||||||
variables = {}
|
variables = {}
|
||||||
@@ -631,7 +519,6 @@ class AppChatService:
|
|||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
workflow_service = WorkflowService(self.db)
|
|
||||||
payload = DraftRunRequest(
|
payload = DraftRunRequest(
|
||||||
message=message,
|
message=message,
|
||||||
variables=variables,
|
variables=variables,
|
||||||
@@ -639,7 +526,7 @@ class AppChatService:
|
|||||||
stream=True,
|
stream=True,
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
return await workflow_service.run(
|
return await self.workflow_service.run(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
payload=payload,
|
payload=payload,
|
||||||
config=config,
|
config=config,
|
||||||
@@ -666,7 +553,6 @@ class AppChatService:
|
|||||||
|
|
||||||
) -> AsyncGenerator[dict, None]:
|
) -> AsyncGenerator[dict, None]:
|
||||||
"""聊天(流式)"""
|
"""聊天(流式)"""
|
||||||
workflow_service = WorkflowService(self.db)
|
|
||||||
payload = DraftRunRequest(
|
payload = DraftRunRequest(
|
||||||
message=message,
|
message=message,
|
||||||
variables=variables,
|
variables=variables,
|
||||||
@@ -675,7 +561,7 @@ class AppChatService:
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
files=files
|
files=files
|
||||||
)
|
)
|
||||||
async for event in workflow_service.run_stream(
|
async for event in self.workflow_service.run_stream(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
payload=payload,
|
payload=payload,
|
||||||
config=config,
|
config=config,
|
||||||
|
|||||||
@@ -1791,372 +1791,6 @@ class AppService:
|
|||||||
|
|
||||||
return shares
|
return shares
|
||||||
|
|
||||||
# ==================== 试运行功能 ====================
|
|
||||||
|
|
||||||
async def draft_run(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
message: str,
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
|
||||||
workspace_id: Optional[uuid.UUID] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""试运行 Agent(使用当前草稿配置)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
app_id: 应用ID
|
|
||||||
message: 用户消息
|
|
||||||
conversation_id: 会话ID(用于多轮对话)
|
|
||||||
user_id: 用户ID(用于会话管理)
|
|
||||||
variables: 自定义变量参数值
|
|
||||||
workspace_id: 工作空间ID(用于权限验证)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 包含 AI 回复和元数据的字典
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ResourceNotFoundException: 当应用不存在时
|
|
||||||
BusinessException: 当应用类型不支持或配置缺失时
|
|
||||||
"""
|
|
||||||
from app.services.draft_run_service import DraftRunService
|
|
||||||
|
|
||||||
logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
|
|
||||||
|
|
||||||
# 1. 验证应用
|
|
||||||
app = self._get_app_or_404(app_id)
|
|
||||||
|
|
||||||
if app.type != "agent":
|
|
||||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
|
||||||
|
|
||||||
# 只读操作,允许访问共享应用
|
|
||||||
self._validate_app_accessible(app, workspace_id)
|
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
|
||||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
|
||||||
agent_cfg = self.db.scalars(stmt).first()
|
|
||||||
|
|
||||||
if not agent_cfg:
|
|
||||||
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
|
||||||
|
|
||||||
# 3. 获取模型配置
|
|
||||||
model_config = None
|
|
||||||
if agent_cfg.default_model_config_id:
|
|
||||||
from app.models import ModelConfig
|
|
||||||
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
|
|
||||||
|
|
||||||
if not model_config:
|
|
||||||
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
|
||||||
|
|
||||||
# 4. 调用试运行服务
|
|
||||||
logger.debug(
|
|
||||||
"准备调用试运行服务",
|
|
||||||
extra={
|
|
||||||
"app_id": str(app_id),
|
|
||||||
"model": model_config.name,
|
|
||||||
"has_conversation_id": bool(conversation_id),
|
|
||||||
"has_variables": bool(variables)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
draft_service = DraftRunService(self.db)
|
|
||||||
result = await draft_service.run(
|
|
||||||
agent_config=agent_cfg,
|
|
||||||
model_config=model_config,
|
|
||||||
message=message,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_id=user_id,
|
|
||||||
variables=variables
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"试运行服务返回结果",
|
|
||||||
extra={
|
|
||||||
"result_type": str(type(result)),
|
|
||||||
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict",
|
|
||||||
"has_message": "message" in result if isinstance(result, dict) else False,
|
|
||||||
"has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"试运行完成",
|
|
||||||
extra={
|
|
||||||
"app_id": str(app_id),
|
|
||||||
"elapsed_time": result.get("elapsed_time"),
|
|
||||||
"model": model_config.name
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def draft_run_stream(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
message: str,
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
|
||||||
workspace_id: Optional[uuid.UUID] = None
|
|
||||||
):
|
|
||||||
"""试运行 Agent(流式返回)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
app_id: 应用ID
|
|
||||||
message: 用户消息
|
|
||||||
conversation_id: 会话ID(用于多轮对话)
|
|
||||||
user_id: 用户ID(用于会话管理)
|
|
||||||
variables: 自定义变量参数值
|
|
||||||
workspace_id: 工作空间ID(用于权限验证)
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
str: SSE 格式的事件数据
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ResourceNotFoundException: 当应用不存在时
|
|
||||||
BusinessException: 当应用类型不支持或配置缺失时
|
|
||||||
"""
|
|
||||||
from app.services.draft_run_service import DraftRunService
|
|
||||||
|
|
||||||
logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
|
|
||||||
|
|
||||||
# 1. 验证应用
|
|
||||||
app = self._get_app_or_404(app_id)
|
|
||||||
|
|
||||||
if app.type != "agent":
|
|
||||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
|
||||||
|
|
||||||
# 只读操作,允许访问共享应用
|
|
||||||
self._validate_app_accessible(app, workspace_id)
|
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
|
||||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
|
||||||
agent_cfg = self.db.scalars(stmt).first()
|
|
||||||
|
|
||||||
if not agent_cfg:
|
|
||||||
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
|
||||||
|
|
||||||
# 3. 获取模型配置
|
|
||||||
model_config = None
|
|
||||||
if agent_cfg.default_model_config_id:
|
|
||||||
from app.models import ModelConfig
|
|
||||||
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
|
|
||||||
|
|
||||||
if not model_config:
|
|
||||||
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
|
||||||
|
|
||||||
# 4. 调用流式试运行服务
|
|
||||||
draft_service = DraftRunService(self.db)
|
|
||||||
async for event in draft_service.run_stream(
|
|
||||||
agent_config=agent_cfg,
|
|
||||||
model_config=model_config,
|
|
||||||
message=message,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_id=user_id,
|
|
||||||
variables=variables
|
|
||||||
):
|
|
||||||
yield event
|
|
||||||
|
|
||||||
# ==================== 多模型对比试运行 ====================
|
|
||||||
|
|
||||||
async def draft_run_compare(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
message: str,
|
|
||||||
models: List[app_schema.ModelCompareItem],
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
|
||||||
workspace_id: Optional[uuid.UUID] = None,
|
|
||||||
parallel: bool = True,
|
|
||||||
timeout: int = 60
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""多模型对比试运行
|
|
||||||
|
|
||||||
Args:
|
|
||||||
app_id: 应用ID
|
|
||||||
message: 用户消息
|
|
||||||
models: 要对比的模型列表
|
|
||||||
conversation_id: 会话ID
|
|
||||||
user_id: 用户ID
|
|
||||||
variables: 变量参数
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
parallel: 是否并行执行
|
|
||||||
timeout: 超时时间(秒)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 对比结果
|
|
||||||
"""
|
|
||||||
from app.models import ModelConfig
|
|
||||||
from app.services.draft_run_service import DraftRunService
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"多模型对比试运行",
|
|
||||||
extra={
|
|
||||||
"app_id": str(app_id),
|
|
||||||
"model_count": len(models),
|
|
||||||
"parallel": parallel
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. 验证应用
|
|
||||||
app = self._get_app_or_404(app_id)
|
|
||||||
if app.type != "agent":
|
|
||||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
|
||||||
|
|
||||||
# 只读操作,允许访问共享应用
|
|
||||||
self._validate_app_accessible(app, workspace_id)
|
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
|
||||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
|
||||||
agent_cfg = self.db.scalars(stmt).first()
|
|
||||||
if not agent_cfg:
|
|
||||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
|
||||||
|
|
||||||
# 3. 准备所有模型配置
|
|
||||||
model_configs = []
|
|
||||||
for model_item in models:
|
|
||||||
model_config = self.db.get(ModelConfig, model_item.model_config_id)
|
|
||||||
if not model_config:
|
|
||||||
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
|
||||||
|
|
||||||
# 合并参数:agent配置参数 + 请求覆盖参数
|
|
||||||
merged_parameters = {
|
|
||||||
**(agent_cfg.model_parameters or {}),
|
|
||||||
**(model_item.model_parameters or {})
|
|
||||||
}
|
|
||||||
|
|
||||||
model_configs.append({
|
|
||||||
"model_config": model_config,
|
|
||||||
"parameters": merged_parameters,
|
|
||||||
"label": model_item.label or model_config.name,
|
|
||||||
"model_config_id": model_item.model_config_id
|
|
||||||
})
|
|
||||||
|
|
||||||
# 4. 调用 DraftRunService 的对比方法
|
|
||||||
draft_service = DraftRunService(self.db)
|
|
||||||
result = await draft_service.run_compare(
|
|
||||||
agent_config=agent_cfg,
|
|
||||||
models=model_configs,
|
|
||||||
message=message,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_id=user_id,
|
|
||||||
variables=variables,
|
|
||||||
parallel=parallel,
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"多模型对比完成",
|
|
||||||
extra={
|
|
||||||
"app_id": str(app_id),
|
|
||||||
"successful": result["successful_count"],
|
|
||||||
"failed": result["failed_count"]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def draft_run_compare_stream(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
message: str,
|
|
||||||
models: List[app_schema.ModelCompareItem],
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
|
||||||
workspace_id: Optional[uuid.UUID] = None,
|
|
||||||
parallel: bool = True,
|
|
||||||
timeout: int = 60
|
|
||||||
):
|
|
||||||
"""多模型对比试运行(流式返回)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
app_id: 应用ID
|
|
||||||
message: 用户消息
|
|
||||||
models: 要对比的模型列表
|
|
||||||
conversation_id: 会话ID
|
|
||||||
user_id: 用户ID
|
|
||||||
variables: 变量参数
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
timeout: 超时时间(秒)
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
str: SSE 格式的事件数据
|
|
||||||
"""
|
|
||||||
from app.models import ModelConfig
|
|
||||||
from app.services.draft_run_service import DraftRunService
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"多模型对比流式试运行",
|
|
||||||
extra={
|
|
||||||
"app_id": str(app_id),
|
|
||||||
"model_count": len(models)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. 验证应用
|
|
||||||
app = self._get_app_or_404(app_id)
|
|
||||||
if app.type != "agent":
|
|
||||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
|
||||||
|
|
||||||
# 只读操作,允许访问共享应用
|
|
||||||
self._validate_app_accessible(app, workspace_id)
|
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
|
||||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
|
||||||
agent_cfg = self.db.scalars(stmt).first()
|
|
||||||
if not agent_cfg:
|
|
||||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
|
||||||
|
|
||||||
# 3. 准备所有模型配置
|
|
||||||
model_configs = []
|
|
||||||
for model_item in models:
|
|
||||||
model_config = self.db.get(ModelConfig, model_item.model_config_id)
|
|
||||||
if not model_config:
|
|
||||||
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
|
||||||
|
|
||||||
# 合并参数:agent配置参数 + 请求覆盖参数
|
|
||||||
merged_parameters = {
|
|
||||||
**(agent_cfg.model_parameters or {}),
|
|
||||||
**(model_item.model_parameters or {})
|
|
||||||
}
|
|
||||||
|
|
||||||
model_configs.append({
|
|
||||||
"model_config": model_config,
|
|
||||||
"parameters": merged_parameters,
|
|
||||||
"label": model_item.label or model_config.name,
|
|
||||||
"model_config_id": model_item.model_config_id
|
|
||||||
})
|
|
||||||
|
|
||||||
# 4. 调用 DraftRunService 的流式对比方法
|
|
||||||
draft_service = DraftRunService(self.db)
|
|
||||||
async for event in draft_service.run_compare_stream(
|
|
||||||
agent_config=agent_cfg,
|
|
||||||
models=model_configs,
|
|
||||||
message=message,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_id=user_id,
|
|
||||||
variables=variables,
|
|
||||||
parallel=parallel,
|
|
||||||
timeout=timeout
|
|
||||||
):
|
|
||||||
yield event
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"多模型对比流式完成",
|
|
||||||
extra={"app_id": str(app_id)}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 向后兼容的函数接口 ====================
|
# ==================== 向后兼容的函数接口 ====================
|
||||||
# 保留函数接口以兼容现有代码,但内部使用服务类
|
# 保留函数接口以兼容现有代码,但内部使用服务类
|
||||||
|
|
||||||
@@ -2278,53 +1912,6 @@ def get_apps_by_ids(
|
|||||||
return service.get_apps_by_ids(app_ids, workspace_id)
|
return service.get_apps_by_ids(app_ids, workspace_id)
|
||||||
|
|
||||||
|
|
||||||
# ==================== 向后兼容的函数接口 ====================
|
|
||||||
|
|
||||||
async def draft_run(
|
|
||||||
db: Session,
|
|
||||||
*,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
message: str,
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
|
||||||
workspace_id: Optional[uuid.UUID] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""试运行 Agent(向后兼容接口)"""
|
|
||||||
service = AppService(db)
|
|
||||||
return await service.draft_run(
|
|
||||||
app_id=app_id,
|
|
||||||
message=message,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_id=user_id,
|
|
||||||
variables=variables,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def draft_run_stream(
|
|
||||||
db: Session,
|
|
||||||
*,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
message: str,
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
|
||||||
workspace_id: Optional[uuid.UUID] = None
|
|
||||||
):
|
|
||||||
"""试运行 Agent 流式返回(向后兼容接口)"""
|
|
||||||
service = AppService(db)
|
|
||||||
async for event in service.draft_run_stream(
|
|
||||||
app_id=app_id,
|
|
||||||
message=message,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_id=user_id,
|
|
||||||
variables=variables,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
):
|
|
||||||
yield event
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 依赖注入函数 ====================
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
def get_app_service(
|
def get_app_service(
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.agent.agent_middleware import AgentMiddleware
|
from app.core.agent.agent_middleware import AgentMiddleware
|
||||||
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
@@ -26,6 +27,7 @@ from app.repositories.tool_repository import ToolRepository
|
|||||||
from app.schemas.app_schema import FileInput
|
from app.schemas.app_schema import FileInput
|
||||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||||
from app.services import task_service
|
from app.services import task_service
|
||||||
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.langchain_tool_server import Search
|
from app.services.langchain_tool_server import Search
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.model_parameter_merger import ModelParameterMerger
|
from app.services.model_parameter_merger import ModelParameterMerger
|
||||||
@@ -52,8 +54,12 @@ class LongTermMemoryInput(BaseModel):
|
|||||||
description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
|
description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
|
||||||
|
|
||||||
|
|
||||||
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,
|
def create_long_term_memory_tool(
|
||||||
user_rag_memory_id: Optional[str] = None):
|
memory_config: Dict[str, Any],
|
||||||
|
end_user_id: str,
|
||||||
|
storage_type: Optional[str] = None,
|
||||||
|
user_rag_memory_id: Optional[str] = None
|
||||||
|
):
|
||||||
"""创建记忆工具,
|
"""创建记忆工具,
|
||||||
|
|
||||||
|
|
||||||
@@ -61,6 +67,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
memory_config: 记忆配置
|
memory_config: 记忆配置
|
||||||
end_user_id: 用户ID
|
end_user_id: 用户ID
|
||||||
storage_type: 存储类型(可选)
|
storage_type: 存储类型(可选)
|
||||||
|
user_rag_memory_id: 用户RAG记忆ID(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
长期记忆工具
|
长期记忆工具
|
||||||
@@ -188,7 +195,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
|||||||
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 需要检索的问题或关键词
|
kb_config: 知识库配置
|
||||||
|
kb_ids: 知识库ID列表
|
||||||
|
user_id: 用户ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
检索到的相关知识内容
|
检索到的相关知识内容
|
||||||
@@ -232,17 +241,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
|||||||
return knowledge_retrieval_tool
|
return knowledge_retrieval_tool
|
||||||
|
|
||||||
|
|
||||||
class DraftRunService:
|
class AgentRunService:
|
||||||
"""试运行服务类"""
|
"""Agent运行服务类"""
|
||||||
|
|
||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
"""初始化试运行服务
|
"""Agent运行服务
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
"""
|
"""
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prepare_variables(
|
||||||
|
input_vars: dict | None,
|
||||||
|
variables_config: dict | None
|
||||||
|
) -> dict:
|
||||||
|
input_vars = input_vars or {}
|
||||||
|
for variable in variables_config:
|
||||||
|
if variable.get("required") and variable.get("name") not in input_vars:
|
||||||
|
raise ValueError(f"The required parameter '{variable.get('name')}' was not provided")
|
||||||
|
return input_vars
|
||||||
|
|
||||||
|
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
|
||||||
|
"""加载工具配置"""
|
||||||
|
if not tools_config:
|
||||||
|
return []
|
||||||
|
tools = []
|
||||||
|
tool_service = ToolService(self.db)
|
||||||
|
|
||||||
|
if tools_config and isinstance(tools_config, list):
|
||||||
|
for tool_config in tools_config:
|
||||||
|
if tool_config.get("enabled", False):
|
||||||
|
# 根据工具名称查找工具实例
|
||||||
|
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||||
|
if tool_instance:
|
||||||
|
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||||
|
continue
|
||||||
|
# 转换为LangChain工具
|
||||||
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||||
|
tools.append(langchain_tool)
|
||||||
|
elif tools_config and isinstance(tools_config, dict):
|
||||||
|
web_search_choice = tools_config.get("web_search", {})
|
||||||
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
|
if web_search and web_search_enable:
|
||||||
|
search_tool = create_web_search_tool({})
|
||||||
|
tools.append(search_tool)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"已添加网络搜索工具",
|
||||||
|
extra={
|
||||||
|
"tool_count": len(tools)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
def load_skill_config(
|
||||||
|
self,
|
||||||
|
skills_config: dict | None,
|
||||||
|
message: str, tenant_id
|
||||||
|
) -> tuple[list, str]:
|
||||||
|
if not skills_config:
|
||||||
|
return [], ""
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
skill_prompts = ""
|
||||||
|
skill_enable = skills_config.get("enabled", False)
|
||||||
|
if skill_enable:
|
||||||
|
middleware = AgentMiddleware(skills=skills_config)
|
||||||
|
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||||
|
tools.extend(skill_tools)
|
||||||
|
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||||
|
|
||||||
|
if skill_configs:
|
||||||
|
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||||
|
tool_to_skill_map)
|
||||||
|
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||||
|
skill_prompts = AgentMiddleware.get_active_prompts(
|
||||||
|
activated_skill_ids, skill_configs
|
||||||
|
)
|
||||||
|
|
||||||
|
return tools, skill_prompts
|
||||||
|
|
||||||
|
def load_knowledge_retrieval_config(
|
||||||
|
self,
|
||||||
|
knowledge_retrieval_config: dict | None,
|
||||||
|
user_id
|
||||||
|
) -> list:
|
||||||
|
if not knowledge_retrieval_config:
|
||||||
|
return []
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
|
||||||
|
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
||||||
|
if kb_ids:
|
||||||
|
# 创建知识库检索工具
|
||||||
|
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id)
|
||||||
|
tools.append(kb_tool)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"已添加知识库检索工具",
|
||||||
|
extra={
|
||||||
|
"kb_ids": kb_ids,
|
||||||
|
"tool_count": len(tools)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
def load_memory_config(
|
||||||
|
self,
|
||||||
|
memory_config: dict | None,
|
||||||
|
user_id,
|
||||||
|
storage_type,
|
||||||
|
user_rag_memory_id
|
||||||
|
) -> tuple[list, bool]:
|
||||||
|
"""加载长期记忆配置"""
|
||||||
|
if not memory_config:
|
||||||
|
return [], False
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
if memory_config.get("enabled"):
|
||||||
|
if user_id:
|
||||||
|
# 创建长期记忆工具
|
||||||
|
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||||
|
user_rag_memory_id)
|
||||||
|
tools.append(memory_tool)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"已添加长期记忆工具",
|
||||||
|
extra={
|
||||||
|
"user_id": user_id,
|
||||||
|
"tool_count": len(tools)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return tools, bool(memory_config.get("enabled"))
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -270,19 +403,21 @@ class DraftRunService:
|
|||||||
conversation_id: 会话ID(用于多轮对话)
|
conversation_id: 会话ID(用于多轮对话)
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
variables: 自定义变量参数值
|
variables: 自定义变量参数值
|
||||||
|
storage_type: 存储类型(可选)
|
||||||
|
user_rag_memory_id: 用户RAG记忆ID(可选)
|
||||||
|
web_search: 是否启用网络搜索(默认True)
|
||||||
|
memory: 是否启用长期记忆(默认True)
|
||||||
|
sub_agent: 是否为子代理调用(默认False)
|
||||||
|
files: 多模态文件列表(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含 AI 回复和元数据的字典
|
Dict: 包含 AI 回复和元数据的字典
|
||||||
"""
|
"""
|
||||||
memory_flag = False
|
|
||||||
|
|
||||||
print('===========', storage_type)
|
|
||||||
|
|
||||||
print(user_id)
|
|
||||||
if variables == None: variables = {}
|
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
tools_config: dict | list | None = agent_config.tools
|
||||||
|
skills_config: dict | None = agent_config.skills
|
||||||
|
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||||||
|
memory_config: dict | None = agent_config.memory
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 获取 API Key 配置
|
# 1. 获取 API Key 配置
|
||||||
@@ -302,112 +437,40 @@ class DraftRunService:
|
|||||||
agent_config=agent_config
|
agent_config=agent_config
|
||||||
)
|
)
|
||||||
|
|
||||||
items_params = variables
|
if sub_agent:
|
||||||
|
variables = self.prepare_variables(variables, agent_config.variables)
|
||||||
|
else:
|
||||||
|
# FIXME: subagent input valid
|
||||||
|
variables = variables or {}
|
||||||
|
|
||||||
system_prompt = render_prompt_message(
|
system_prompt = render_prompt_message(
|
||||||
agent_config.system_prompt, # 修正拼写错误
|
agent_config.system_prompt,
|
||||||
PromptMessageRole.USER,
|
PromptMessageRole.USER,
|
||||||
items_params
|
variables
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 处理系统提示词(支持变量替换)
|
# 3. 处理系统提示词(支持变量替换)
|
||||||
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
||||||
print('系统提示词:', system_prompt)
|
|
||||||
|
|
||||||
# 4. 准备工具列表
|
# 4. 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
tool_service = ToolService(self.db)
|
|
||||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
|
||||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
|
||||||
for tool_config in agent_config.tools:
|
tools.extend(skill_tools)
|
||||||
print("+" * 50)
|
if skill_prompts:
|
||||||
print(f"agent_config:{agent_config}")
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
print(f"tool_config:{tool_config}")
|
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
|
||||||
if tool_config.get("enabled", False):
|
|
||||||
# 根据工具名称查找工具实例
|
|
||||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
|
||||||
if tool_instance:
|
|
||||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
|
||||||
continue
|
|
||||||
# 转换为LangChain工具
|
|
||||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
|
||||||
tools.append(langchain_tool)
|
|
||||||
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
|
||||||
web_tools = agent_config.tools
|
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
if web_search:
|
|
||||||
if web_search_enable:
|
|
||||||
search_tool = create_web_search_tool({})
|
|
||||||
tools.append(search_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加网络搜索工具",
|
|
||||||
extra={
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 加载技能关联的工具
|
|
||||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
|
||||||
skills = agent_config.skills
|
|
||||||
skill_enable = skills.get("enabled", False)
|
|
||||||
if skill_enable:
|
|
||||||
middleware = AgentMiddleware(skills=skills)
|
|
||||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
|
||||||
tools.extend(skill_tools)
|
|
||||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
|
||||||
|
|
||||||
# 应用动态过滤
|
|
||||||
if skill_configs:
|
|
||||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
|
||||||
tool_to_skill_map)
|
|
||||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
|
||||||
active_prompts = AgentMiddleware.get_active_prompts(
|
|
||||||
activated_skill_ids, skill_configs
|
|
||||||
)
|
|
||||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
|
||||||
|
|
||||||
# 添加知识库检索工具
|
|
||||||
if agent_config.knowledge_retrieval:
|
|
||||||
kb_config = agent_config.knowledge_retrieval
|
|
||||||
knowledge_bases = kb_config.get("knowledge_bases", [])
|
|
||||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
|
||||||
if kb_ids:
|
|
||||||
# 创建知识库检索工具
|
|
||||||
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
|
|
||||||
tools.append(kb_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加知识库检索工具",
|
|
||||||
extra={
|
|
||||||
"kb_ids": kb_ids,
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
|
memory_flag = False
|
||||||
if memory:
|
if memory:
|
||||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
memory_tools, memory_flag = self.load_memory_config(
|
||||||
memory_flag = True
|
memory_config, user_id, storage_type, user_rag_memory_id
|
||||||
|
)
|
||||||
memory_config = agent_config.memory
|
tools.extend(memory_tools)
|
||||||
if user_id:
|
|
||||||
# 创建长期记忆工具
|
|
||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
|
||||||
user_rag_memory_id)
|
|
||||||
tools.append(memory_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加长期记忆工具",
|
|
||||||
extra={
|
|
||||||
"user_id": user_id,
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. 创建 LangChain Agent
|
# 4. 创建 LangChain Agent
|
||||||
agent = LangChainAgent(
|
agent = LangChainAgent(
|
||||||
@@ -432,7 +495,7 @@ class DraftRunService:
|
|||||||
|
|
||||||
# 6. 加载历史消息
|
# 6. 加载历史消息
|
||||||
history = []
|
history = []
|
||||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
if memory_config and memory_config.get("enabled"):
|
||||||
history = await self._load_conversation_history(
|
history = await self._load_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
max_history=agent_config.memory.get("max_history", 10)
|
max_history=agent_config.memory.get("max_history", 10)
|
||||||
@@ -482,7 +545,7 @@ class DraftRunService:
|
|||||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||||
|
|
||||||
# 9. 保存会话消息
|
# 9. 保存会话消息
|
||||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
if not sub_agent and memory_config and memory_config.get("enabled"):
|
||||||
await self._save_conversation_message(
|
await self._save_conversation_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
@@ -557,16 +620,21 @@ class DraftRunService:
|
|||||||
Yields:
|
Yields:
|
||||||
str: SSE 格式的事件数据
|
str: SSE 格式的事件数据
|
||||||
"""
|
"""
|
||||||
memory_flag = False
|
tools_config: dict | list | None = agent_config.tools
|
||||||
if variables == None: variables = {}
|
skills_config: dict | None = agent_config.skills
|
||||||
|
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
memory_config: dict | None = agent_config.memory
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 获取 API Key 配置
|
# 1. 获取 API Key 配置
|
||||||
api_key_config = await self._get_api_key(model_config.id)
|
api_key_config = await self._get_api_key(model_config.id)
|
||||||
|
if not sub_agent:
|
||||||
|
variables = self.prepare_variables(variables, agent_config.variables)
|
||||||
|
else:
|
||||||
|
# FIXME: subagent input valid
|
||||||
|
variables = variables or {}
|
||||||
|
|
||||||
# 2. 合并模型参数
|
# 2. 合并模型参数
|
||||||
effective_params = ModelParameterMerger.get_effective_parameters(
|
effective_params = ModelParameterMerger.get_effective_parameters(
|
||||||
@@ -588,95 +656,22 @@ class DraftRunService:
|
|||||||
# 4. 准备工具列表
|
# 4. 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
tool_service = ToolService(self.db)
|
|
||||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
|
||||||
for tool_config in agent_config.tools:
|
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
|
||||||
# print("+"*50)
|
tools.extend(skill_tools)
|
||||||
# print(f"agent_config:{agent_config}")
|
if skill_prompts:
|
||||||
# print(f"tool_config:{tool_config}")
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
if tool_config.get("enabled", False):
|
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
|
||||||
# 根据工具名称查找工具实例
|
|
||||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
|
||||||
if tool_instance:
|
|
||||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
|
||||||
continue
|
|
||||||
# 转换为LangChain工具
|
|
||||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
|
||||||
tools.append(langchain_tool)
|
|
||||||
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
|
||||||
web_tools = agent_config.tools
|
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
if web_search:
|
|
||||||
if web_search_enable:
|
|
||||||
search_tool = create_web_search_tool({})
|
|
||||||
tools.append(search_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加网络搜索工具",
|
|
||||||
extra={
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 加载技能关联的工具
|
|
||||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
|
||||||
skills = agent_config.skills
|
|
||||||
skill_enable = skills.get("enabled", False)
|
|
||||||
if skill_enable:
|
|
||||||
middleware = AgentMiddleware(skills=skills)
|
|
||||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
|
||||||
tools.extend(skill_tools)
|
|
||||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
|
||||||
|
|
||||||
# 应用动态过滤
|
|
||||||
if skill_configs:
|
|
||||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
|
||||||
tool_to_skill_map)
|
|
||||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
|
||||||
active_prompts = AgentMiddleware.get_active_prompts(
|
|
||||||
activated_skill_ids, skill_configs
|
|
||||||
)
|
|
||||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
|
||||||
|
|
||||||
# 添加知识库检索工具
|
|
||||||
if agent_config.knowledge_retrieval:
|
|
||||||
kb_config = agent_config.knowledge_retrieval
|
|
||||||
knowledge_bases = kb_config.get("knowledge_bases", [])
|
|
||||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
|
||||||
if kb_ids:
|
|
||||||
# 创建知识库检索工具
|
|
||||||
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
|
|
||||||
tools.append(kb_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加知识库检索工具",
|
|
||||||
extra={
|
|
||||||
"kb_ids": kb_ids,
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
|
memory_flag = False
|
||||||
if memory:
|
if memory:
|
||||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type,
|
||||||
memory_flag = True
|
user_rag_memory_id)
|
||||||
memory_config = agent_config.memory
|
tools.extend(memory_tools)
|
||||||
if user_id:
|
|
||||||
# 创建长期记忆工具
|
|
||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
|
||||||
user_rag_memory_id)
|
|
||||||
tools.append(memory_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加长期记忆工具",
|
|
||||||
extra={
|
|
||||||
"user_id": user_id,
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. 创建 LangChain Agent
|
# 4. 创建 LangChain Agent
|
||||||
agent = LangChainAgent(
|
agent = LangChainAgent(
|
||||||
@@ -702,10 +697,10 @@ class DraftRunService:
|
|||||||
|
|
||||||
# 6. 加载历史消息
|
# 6. 加载历史消息
|
||||||
history = []
|
history = []
|
||||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
if memory_config and memory_config.get("enabled"):
|
||||||
history = await self._load_conversation_history(
|
history = await self._load_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
max_history=agent_config.memory.get("max_history", 10)
|
max_history=memory_config.get("max_history", 10)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 处理多模态文件
|
# 6. 处理多模态文件
|
||||||
@@ -763,7 +758,7 @@ class DraftRunService:
|
|||||||
})
|
})
|
||||||
|
|
||||||
# 10. 保存会话消息
|
# 10. 保存会话消息
|
||||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
if not sub_agent and memory_config and memory_config.get("enabled"):
|
||||||
await self._save_conversation_message(
|
await self._save_conversation_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
@@ -969,7 +964,6 @@ class DraftRunService:
|
|||||||
List[Dict]: 历史消息列表
|
List[Dict]: 历史消息列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from app.services.conversation_service import ConversationService
|
|
||||||
|
|
||||||
conversation_service = ConversationService(self.db)
|
conversation_service = ConversationService(self.db)
|
||||||
history = conversation_service.get_conversation_history(
|
history = conversation_service.get_conversation_history(
|
||||||
@@ -1489,6 +1483,15 @@ class DraftRunService:
|
|||||||
"conversation_id": returned_conversation_id,
|
"conversation_id": returned_conversation_id,
|
||||||
"content": chunk
|
"content": chunk
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
if event_type == "error" and event_data:
|
||||||
|
await event_queue.put(self._format_sse_event("model_error", {
|
||||||
|
"model_index": idx,
|
||||||
|
"model_config_id": model_config_id,
|
||||||
|
"label": model_label,
|
||||||
|
"conversation_id": returned_conversation_id,
|
||||||
|
"error": event_data.get("error", "未知错误")
|
||||||
|
}))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"解析流式事件失败: {e}")
|
logger.warning(f"解析流式事件失败: {e}")
|
||||||
finally:
|
finally:
|
||||||
@@ -1673,41 +1676,3 @@ class DraftRunService:
|
|||||||
"total_time": sum(r.get("elapsed_time", 0) for r in results)
|
"total_time": sum(r.get("elapsed_time", 0) for r in results)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def draft_run(
|
|
||||||
db: Session,
|
|
||||||
*,
|
|
||||||
agent_config: AgentConfig,
|
|
||||||
model_config: ModelConfig,
|
|
||||||
message: str,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
kb_ids: Optional[List[str]] = None,
|
|
||||||
similarity_threshold: float = 0.7,
|
|
||||||
top_k: int = 3
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""试运行 Agent(便捷函数)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
agent_config: Agent 配置
|
|
||||||
model_config: 模型配置
|
|
||||||
message: 用户消息
|
|
||||||
user_id: 用户ID
|
|
||||||
kb_ids: 知识库ID列表
|
|
||||||
similarity_threshold: 相似度阈值
|
|
||||||
top_k: 检索返回的文档数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 包含 AI 回复和元数据的字典
|
|
||||||
"""
|
|
||||||
service = DraftRunService(db)
|
|
||||||
return await service.run(
|
|
||||||
agent_config=agent_config,
|
|
||||||
model_config=model_config,
|
|
||||||
message=message,
|
|
||||||
user_id=user_id,
|
|
||||||
kb_ids=kb_ids,
|
|
||||||
similarity_threshold=similarity_threshold,
|
|
||||||
top_k=top_k
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ load_dotenv()
|
|||||||
|
|
||||||
# 读取web_search环境变量
|
# 读取web_search环境变量
|
||||||
web_search_value = os.getenv('web_search')
|
web_search_value = os.getenv('web_search')
|
||||||
|
|
||||||
|
|
||||||
def Search(query):
|
def Search(query):
|
||||||
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
|
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
|
||||||
api_key = web_search_value
|
api_key = web_search_value
|
||||||
@@ -18,23 +20,24 @@ def Search(query):
|
|||||||
"role": "user",
|
"role": "user",
|
||||||
"content": query
|
"content": query
|
||||||
}
|
}
|
||||||
], #搜索输入
|
], # 搜索输入
|
||||||
"edition":"standard", #搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。
|
"edition": "standard", # 搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。
|
||||||
"search_source": "baidu_search_v2", #使用的搜索引擎版本
|
"search_source": "baidu_search_v2", # 使用的搜索引擎版本
|
||||||
"resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5
|
"resource_type_filter": [{"type": "web", "top_k": 20}],
|
||||||
|
# 支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5
|
||||||
"search_filter": {
|
"search_filter": {
|
||||||
"range": {
|
"range": {
|
||||||
"page_time": {
|
"page_time": {
|
||||||
"gte": "now-1w/d", #时间查询参数,大于或等于
|
"gte": "now-1w/d", # 时间查询参数,大于或等于
|
||||||
"lt": "now/d", #时间查询参数,小于
|
"lt": "now/d", # 时间查询参数,小于
|
||||||
"gt": "", #时间查询参数,大于
|
"gt": "", # 时间查询参数,大于
|
||||||
"lte": "" #时间查询参数,小于或等于
|
"lte": "" # 时间查询参数,小于或等于
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表
|
"block_websites": ["tieba.baidu.com"], # 需要屏蔽的站点列表
|
||||||
"search_recency_filter":"week", #根据网页发布时间进行筛选,可填值为:week,month,semiyear,year
|
"search_recency_filter": "week", # 根据网页发布时间进行筛选,可填值为:week,month,semiyear,year
|
||||||
"enable_full_content":True #是否输出网页完整原文
|
"enable_full_content": True # 是否输出网页完整原文
|
||||||
}, ensure_ascii=False)
|
}, ensure_ascii=False)
|
||||||
headers = {
|
headers = {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -42,10 +45,10 @@ def Search(query):
|
|||||||
}
|
}
|
||||||
|
|
||||||
response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json()
|
response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json()
|
||||||
content=[]
|
content = []
|
||||||
for i in response['references']:
|
for i in response['references']:
|
||||||
title=i['title']
|
title = i['title']
|
||||||
snippet=i['snippet']
|
snippet = i['snippet']
|
||||||
content.append(title+';'+snippet)
|
content.append(title + ';' + snippet)
|
||||||
content='。'.join(content)
|
content = '。'.join(content)
|
||||||
return content
|
return content
|
||||||
@@ -123,11 +123,14 @@ class MultiAgentOrchestrator:
|
|||||||
user_id: 用户 ID
|
user_id: 用户 ID
|
||||||
variables: 变量参数
|
variables: 变量参数
|
||||||
use_llm_routing: 是否使用 LLM 路由
|
use_llm_routing: 是否使用 LLM 路由
|
||||||
|
web_search: 是否启用网络搜索
|
||||||
|
memory: 是否启用记忆功能
|
||||||
|
storage_type: 存储类型
|
||||||
|
user_rag_memory_id: 用户 RAG 记忆 ID
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
SSE 格式的事件流
|
SSE 格式的事件流
|
||||||
"""
|
"""
|
||||||
import json
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
@@ -200,7 +203,8 @@ class MultiAgentOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
"多 Agent 任务执行失败(流式)",
|
"多 Agent 任务执行失败(流式)",
|
||||||
extra={"error": str(e), "mode": self._normalized_mode}
|
extra={"error": str(e), "mode": self._normalized_mode},
|
||||||
|
exc_info=True
|
||||||
)
|
)
|
||||||
# 发送错误事件
|
# 发送错误事件
|
||||||
yield self._format_sse_event("error", {
|
yield self._format_sse_event("error", {
|
||||||
@@ -1267,7 +1271,7 @@ class MultiAgentOrchestrator:
|
|||||||
Yields:
|
Yields:
|
||||||
SSE 格式的事件流
|
SSE 格式的事件流
|
||||||
"""
|
"""
|
||||||
from app.services.draft_run_service import DraftRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
|
|
||||||
# 获取模型配置
|
# 获取模型配置
|
||||||
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
|
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
|
||||||
@@ -1278,7 +1282,7 @@ class MultiAgentOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 流式执行 Agent
|
# 流式执行 Agent
|
||||||
draft_service = DraftRunService(self.db)
|
draft_service = AgentRunService(self.db)
|
||||||
async for event in draft_service.run_stream(
|
async for event in draft_service.run_stream(
|
||||||
agent_config=agent_config,
|
agent_config=agent_config,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@@ -1320,7 +1324,7 @@ class MultiAgentOrchestrator:
|
|||||||
Returns:
|
Returns:
|
||||||
执行结果
|
执行结果
|
||||||
"""
|
"""
|
||||||
from app.services.draft_run_service import DraftRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
|
|
||||||
# 获取模型配置
|
# 获取模型配置
|
||||||
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
|
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
|
||||||
@@ -1331,7 +1335,7 @@ class MultiAgentOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 执行 Agent
|
# 执行 Agent
|
||||||
draft_service = DraftRunService(self.db)
|
draft_service = AgentRunService(self.db)
|
||||||
result = await draft_service.run(
|
result = await draft_service.run(
|
||||||
agent_config=agent_config,
|
agent_config=agent_config,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@@ -1633,6 +1637,7 @@ class MultiAgentOrchestrator:
|
|||||||
self.memory = config_data.get("memory")
|
self.memory = config_data.get("memory")
|
||||||
self.variables = config_data.get("variables", [])
|
self.variables = config_data.get("variables", [])
|
||||||
self.tools = config_data.get("tools", {})
|
self.tools = config_data.get("tools", {})
|
||||||
|
self.skills = config_data.get("skills", {})
|
||||||
self.default_model_config_id = release.default_model_config_id
|
self.default_model_config_id = release.default_model_config_id
|
||||||
|
|
||||||
return AgentConfigProxy(release, app, config_data)
|
return AgentConfigProxy(release, app, config_data)
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class SkillService:
|
|||||||
if skill and skill.is_active:
|
if skill and skill.is_active:
|
||||||
# 加载技能关联的工具
|
# 加载技能关联的工具
|
||||||
for tool_config in skill.tools:
|
for tool_config in skill.tools:
|
||||||
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
tool = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||||
if tool:
|
if tool:
|
||||||
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
|
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
|
||||||
tools.append(langchain_tool)
|
tools.append(langchain_tool)
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ class ToolService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取工具实例
|
# 获取工具实例
|
||||||
tool = self._get_tool_instance(tool_id, tenant_id)
|
tool = self.get_tool_instance(tool_id, tenant_id)
|
||||||
if not tool:
|
if not tool:
|
||||||
return ToolResult.error_result(
|
return ToolResult.error_result(
|
||||||
error=f"工具不存在: {tool_id}",
|
error=f"工具不存在: {tool_id}",
|
||||||
@@ -335,7 +335,7 @@ class ToolService:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# 获取工具实例
|
# 获取工具实例
|
||||||
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
|
tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
|
||||||
if not tool_instance:
|
if not tool_instance:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -792,7 +792,7 @@ class ToolService:
|
|||||||
"""获取工具配置"""
|
"""获取工具配置"""
|
||||||
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
||||||
|
|
||||||
def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
|
def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
|
||||||
"""获取工具实例"""
|
"""获取工具实例"""
|
||||||
if tool_id in self._tool_cache:
|
if tool_id in self._tool_cache:
|
||||||
return self._tool_cache[tool_id]
|
return self._tool_cache[tool_id]
|
||||||
@@ -1416,7 +1416,7 @@ class ToolService:
|
|||||||
"""测试内置工具连接"""
|
"""测试内置工具连接"""
|
||||||
try:
|
try:
|
||||||
# 获取工具实例
|
# 获取工具实例
|
||||||
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
|
tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
|
||||||
if not tool_instance:
|
if not tool_instance:
|
||||||
return {"success": False, "message": "无法创建工具实例"}
|
return {"success": False, "message": "无法创建工具实例"}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user