diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 653f616c..cdf94345 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -396,10 +396,10 @@ async def draft_run( from app.models import AgentConfig, ModelConfig from sqlalchemy import select from app.core.exceptions import BusinessException - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService service = AppService(db) - draft_service = DraftRunService(db) + draft_service = AgentRunService(db) # 1. 验证应用 app = service._get_app_or_404(app_id) @@ -484,8 +484,8 @@ async def draft_run( } ) - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run( agent_config=agent_cfg, model_config=model_config, @@ -789,8 +789,8 @@ async def draft_run_compare( # 流式返回 if payload.stream: async def event_generator(): - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) async for event in draft_service.run_compare_stream( agent_config=agent_cfg, models=model_configs, @@ -820,8 +820,8 @@ async def draft_run_compare( ) # 非流式返回 - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run_compare( agent_config=agent_cfg, models=model_configs, diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index 5d4dbd10..dba6717d 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -21,6 +21,7 @@ from pydantic import BaseModel, Field T = TypeVar("T") + class RedBearModelConfig(BaseModel): """模型配置基类""" model_name: str @@ -32,17 +33,18 @@ class RedBearModelConfig(BaseModel): timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0"))) # 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置 max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2"))) - concurrency: int = 5 # 并发限流 + concurrency: int = 5 # 并发限流 extra_params: Dict[str, Any] = {} + class RedBearModelFactory: """模型工厂类""" - + @classmethod def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: """根据提供商获取模型参数""" provider = config.provider.lower() - + # 打印供应商信息用于调试 from app.core.logging_config import get_business_logger logger = get_business_logger() @@ -87,7 +89,7 @@ class RedBearModelFactory: "timeout": timeout_config, "max_retries": config.max_retries, **config.extra_params - } + } elif provider == ModelProvider.DASHSCOPE: # DashScope (通义千问) 使用自己的参数格式 # 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数 @@ -104,7 +106,7 @@ class RedBearModelFactory: # region 从 base_url 或 extra_params 获取 from botocore.config import Config as BotoConfig from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id - + max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50")) max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2")) # Configure with increased connection pool @@ -112,16 +114,16 @@ class RedBearModelFactory: max_pool_connections=max_pool_connections, retries={'max_attempts': max_retries, 'mode': 'adaptive'} ) - + # 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID) model_id = normalize_bedrock_model_id(config.model_name) - + params = { "model_id": model_id, "config": boto_config, **config.extra_params } - + # 解析 API key (格式: access_key_id:secret_access_key) if config.api_key and ":" in config.api_key: access_key_id, secret_access_key = config.api_key.split(":", 1) @@ -129,51 +131,52 @@ class RedBearModelFactory: params["aws_secret_access_key"] = secret_access_key elif config.api_key: params["aws_access_key_id"] = config.api_key - + # 设置 region if config.base_url: params["region_name"] = config.base_url elif "region_name" not in params: params["region_name"] = "us-east-1" # 默认区域 - + return params else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) - + @classmethod def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: """根据提供商获取模型参数""" provider = config.provider.lower() if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: - return { + return { "model": config.model_name, # "base_url": config.base_url, "jina_api_key": config.api_key, **config.extra_params - } + } else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) -def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]: + +def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]: """根据模型提供商获取对应的模型类""" provider = config.provider.lower() - + # dashscope 的 omni 模型使用 OpenAI 兼容模式 if provider == ModelProvider.DASHSCOPE and config.is_omni: from langchain_openai import ChatOpenAI return ChatOpenAI - - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : if type == ModelType.LLM: from langchain_openai import OpenAI - return OpenAI + return OpenAI elif type == ModelType.CHAT: from langchain_openai import ChatOpenAI return ChatOpenAI elif provider == ModelProvider.DASHSCOPE: from langchain_community.chat_models import ChatTongyi return ChatTongyi - elif provider == ModelProvider.OLLAMA: + elif provider == ModelProvider.OLLAMA: from langchain_ollama import OllamaLLM return OllamaLLM elif provider == ModelProvider.BEDROCK: @@ -183,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType. else: raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) + def get_provider_embedding_class(provider: str) -> type[Embeddings]: """根据模型提供商获取对应的模型类""" provider = provider.lower() - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: from langchain_openai import OpenAIEmbeddings - return OpenAIEmbeddings + return OpenAIEmbeddings elif provider == ModelProvider.DASHSCOPE: from langchain_community.embeddings import DashScopeEmbeddings - return DashScopeEmbeddings + return DashScopeEmbeddings elif provider == ModelProvider.OLLAMA: from langchain_ollama import OllamaEmbeddings return OllamaEmbeddings @@ -201,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]: else: raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) + def get_provider_rerank_class(provider: str): """根据模型提供商获取对应的模型类""" - provider = provider.lower() - if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + provider = provider.lower() + if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: from langchain_community.document_compressors import JinaRerank - return JinaRerank - # elif provider == ModelProvider.OLLAMA: + return JinaRerank + # elif provider == ModelProvider.OLLAMA: # from langchain_ollama import OllamaEmbeddings # return OllamaEmbeddings else: - raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) \ No newline at end of file + raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py index 98d8bb75..3fbbbdbc 100644 --- a/api/app/core/workflow/nodes/agent/node.py +++ b/api/app/core/workflow/nodes/agent/node.py @@ -16,7 +16,7 @@ from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.variable.base_variable import VariableType from app.db import get_db from app.models import AppRelease -from app.services.draft_run_service import DraftRunService +from app.services.draft_run_service import AgentRunService logger = logging.getLogger(__name__) @@ -39,7 +39,7 @@ class AgentNode(BaseNode): def _output_types(self) -> dict[str, VariableType]: return {"output": VariableType.STRING} - def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]: + def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]: """准备 Agent(公共逻辑) Args: @@ -65,7 +65,7 @@ class AgentNode(BaseNode): if not release: raise ValueError(f"Agent 不存在: {agent_id}") - draft_service = DraftRunService(db) + draft_service = AgentRunService(db) return draft_service, release, message diff --git a/api/app/schemas/api_key_schema.py b/api/app/schemas/api_key_schema.py index d19cf061..323c1a69 100644 --- a/api/app/schemas/api_key_schema.py +++ b/api/app/schemas/api_key_schema.py @@ -155,8 +155,7 @@ class ApiKey(BaseModel): return datetime.datetime.now() > self.expires_at @field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at') - @classmethod - def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]: """将datetime转换为时间戳""" return datetime_to_timestamp(v) @@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel): avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)") @field_serializer('last_used_at') - @classmethod - def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]: """将datetime转换为时间戳""" return datetime_to_timestamp(v) @@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel): created_at: datetime.datetime @field_serializer('created_at') - @classmethod - def serialize_datetime(cls, v: datetime.datetime) -> int: + def serialize_datetime(self, v: datetime.datetime) -> int: """将datetime转换为时间戳""" return datetime_to_timestamp(v) diff --git a/api/app/schemas/multi_agent_schema.py b/api/app/schemas/multi_agent_schema.py index 8fba2929..3573e87c 100644 --- a/api/app/schemas/multi_agent_schema.py +++ b/api/app/schemas/multi_agent_schema.py @@ -64,14 +64,14 @@ class ExecutionConfig(BaseModel): class MultiAgentConfigCreate(BaseModel): """创建多 Agent 配置""" master_agent_id: uuid.UUID = Field(..., description="主 Agent ID") - master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") + master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称") orchestration_mode: str = Field( default="collaboration", pattern="^(collaboration|supervisor)$", description="协作模式:collaboration(协作)| supervisor(监督)" ) sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表") - routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则") + routing_rules: Optional[List[RoutingRule]] = Field(default=None, description="路由规则") execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置") aggregation_strategy: str = Field( default="merge", @@ -83,7 +83,7 @@ class MultiAgentConfigCreate(BaseModel): class MultiAgentConfigUpdate(BaseModel): """更新多 Agent 配置""" master_agent_id: Optional[uuid.UUID] = None - master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") + master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称") default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID") model_parameters: Optional[ModelParameters] = Field( None, diff --git a/api/app/services/agent_tools.py b/api/app/services/agent_tools.py index 3ca7bddd..a4768b51 100644 --- a/api/app/services/agent_tools.py +++ b/api/app/services/agent_tools.py @@ -263,8 +263,8 @@ def create_agent_invocation_tool( try: # 9. 调用 Agent - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run( agent_config=agent_config, diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index e6ac227b..5430d2f9 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -10,25 +10,24 @@ from sqlalchemy.orm import Session from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.langchain_agent import LangChainAgent -from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger -from app.db import get_db, get_db_context -from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig -from app.schemas import DraftRunRequest -from app.schemas.app_schema import FileInput -from app.services.tool_service import ToolService -from app.repositories.tool_repository import ToolRepository from app.db import get_db from app.models import MultiAgentConfig, AgentConfig +from app.models import WorkflowConfig +from app.repositories.tool_repository import ToolRepository +from app.schemas import DraftRunRequest +from app.schemas.app_schema import FileInput from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService -from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool +from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \ + AgentRunService from app.services.draft_run_service import create_web_search_tool from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator -from app.services.workflow_service import WorkflowService from app.services.multimodal_service import MultimodalService +from app.services.tool_service import ToolService +from app.services.workflow_service import WorkflowService logger = get_business_logger() @@ -39,6 +38,8 @@ class AppChatService: def __init__(self, db: Session): self.db = db self.conversation_service = ConversationService(db) + self.agent_service = AgentRunService(db) + self.workflow_service = WorkflowService(db) async def agnet_chat( self, @@ -55,12 +56,10 @@ class AppChatService: files: Optional[List[FileInput]] = None # 新增:多模态文件 ) -> Dict[str, Any]: """聊天(非流式)""" - start_time = time.time() config_id = None - if variables is None: - variables = {} + variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID model_config_id = config.default_model_config_id @@ -79,74 +78,20 @@ class AppChatService: tools = [] # 获取工具服务 - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) - # 从配置中获取启用的工具 - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): - for tool_config in config.tools: - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(config, 'skills') and config.skills: - skills = config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - knowledge_retrieval = config.knowledge_retrieval - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) - - # 添加长期记忆工具 + tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id)) + skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) memory_flag = False - if memory == True: - memory_config = config.memory - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) + if memory: + memory_tools, memory_flag = self.agent_service.load_memory_config( + config.memory, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 获取模型参数 model_parameters = config.model_parameters @@ -246,10 +191,9 @@ class AppChatService: try: start_time = time.time() config_id = None + yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" - if variables is None: - variables = {} - + variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID model_config_id = config.default_model_config_id api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) @@ -267,73 +211,22 @@ class AppChatService: tools = [] # 获取工具服务 - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): - for tool_config in config.tools: - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(config, 'skills') and config.skills: - skills = config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - knowledge_retrieval = config.knowledge_retrieval - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) + tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id)) + skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) # 添加长期记忆工具 memory_flag = False if memory: - memory_config = config.memory - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) + memory_tools, memory_flag = self.agent_service.load_memory_config( + config.memory, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 获取模型参数 model_parameters = config.model_parameters @@ -372,9 +265,6 @@ class AppChatService: processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") - # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" - # 流式调用 Agent(支持多模态) full_content = "" total_tokens = 0 @@ -418,7 +308,7 @@ class AppChatService: ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) # 发送结束事件 - end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} + end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" logger.info( @@ -437,7 +327,7 @@ class AppChatService: except Exception as e: logger.error(f"流式聊天失败: {str(e)}", exc_info=True) # 发送错误事件 - yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + yield f"event: end\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" async def multi_agent_chat( self, @@ -491,10 +381,10 @@ class AppChatService: "mode": result.get("mode"), "elapsed_time": result.get("elapsed_time"), "usage": result.get("usage", { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - }) + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) } ) @@ -524,8 +414,6 @@ class AppChatService: """多 Agent 聊天(流式)""" start_time = time.time() - actual_config_id = None - config_id = actual_config_id if variables is None: variables = {} @@ -631,7 +519,6 @@ class AppChatService: user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """聊天(非流式)""" - workflow_service = WorkflowService(self.db) payload = DraftRunRequest( message=message, variables=variables, @@ -639,7 +526,7 @@ class AppChatService: stream=True, user_id=user_id ) - return await workflow_service.run( + return await self.workflow_service.run( app_id=app_id, payload=payload, config=config, @@ -666,7 +553,6 @@ class AppChatService: ) -> AsyncGenerator[dict, None]: """聊天(流式)""" - workflow_service = WorkflowService(self.db) payload = DraftRunRequest( message=message, variables=variables, @@ -675,7 +561,7 @@ class AppChatService: user_id=user_id, files=files ) - async for event in workflow_service.run_stream( + async for event in self.workflow_service.run_stream( app_id=app_id, payload=payload, config=config, diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index c5919af9..a248f869 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1791,372 +1791,6 @@ class AppService: return shares - # ==================== 试运行功能 ==================== - - async def draft_run( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None - ) -> Dict[str, Any]: - """试运行 Agent(使用当前草稿配置) - - Args: - app_id: 应用ID - message: 用户消息 - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID(用于会话管理) - variables: 自定义变量参数值 - workspace_id: 工作空间ID(用于权限验证) - - Returns: - Dict: 包含 AI 回复和元数据的字典 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或配置缺失时 - """ - from app.services.draft_run_service import DraftRunService - - logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 4. 调用试运行服务 - logger.debug( - "准备调用试运行服务", - extra={ - "app_id": str(app_id), - "model": model_config.name, - "has_conversation_id": bool(conversation_id), - "has_variables": bool(variables) - } - ) - - draft_service = DraftRunService(self.db) - result = await draft_service.run( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables - ) - - logger.debug( - "试运行服务返回结果", - extra={ - "result_type": str(type(result)), - "result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict", - "has_message": "message" in result if isinstance(result, dict) else False, - "has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False - } - ) - - logger.info( - "试运行完成", - extra={ - "app_id": str(app_id), - "elapsed_time": result.get("elapsed_time"), - "model": model_config.name - } - ) - - return result - - async def draft_run_stream( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None - ): - """试运行 Agent(流式返回) - - Args: - app_id: 应用ID - message: 用户消息 - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID(用于会话管理) - variables: 自定义变量参数值 - workspace_id: 工作空间ID(用于权限验证) - - Yields: - str: SSE 格式的事件数据 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或配置缺失时 - """ - from app.services.draft_run_service import DraftRunService - - logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 4. 调用流式试运行服务 - draft_service = DraftRunService(self.db) - async for event in draft_service.run_stream( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables - ): - yield event - - # ==================== 多模型对比试运行 ==================== - - async def draft_run_compare( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 - ) -> Dict[str, Any]: - """多模型对比试运行 - - Args: - app_id: 应用ID - message: 用户消息 - models: 要对比的模型列表 - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - workspace_id: 工作空间ID - parallel: 是否并行执行 - timeout: 超时时间(秒) - - Returns: - Dict: 对比结果 - """ - from app.models import ModelConfig - from app.services.draft_run_service import DraftRunService - - logger.info( - "多模型对比试运行", - extra={ - "app_id": str(app_id), - "model_count": len(models), - "parallel": parallel - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 准备所有模型配置 - model_configs = [] - for model_item in models: - model_config = self.db.get(ModelConfig, model_item.model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - # 合并参数:agent配置参数 + 请求覆盖参数 - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id - }) - - # 4. 调用 DraftRunService 的对比方法 - draft_service = DraftRunService(self.db) - result = await draft_service.run_compare( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout - ) - - logger.info( - "多模型对比完成", - extra={ - "app_id": str(app_id), - "successful": result["successful_count"], - "failed": result["failed_count"] - } - ) - - return result - - async def draft_run_compare_stream( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 - ): - """多模型对比试运行(流式返回) - - Args: - app_id: 应用ID - message: 用户消息 - models: 要对比的模型列表 - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - workspace_id: 工作空间ID - timeout: 超时时间(秒) - - Yields: - str: SSE 格式的事件数据 - """ - from app.models import ModelConfig - from app.services.draft_run_service import DraftRunService - - logger.info( - "多模型对比流式试运行", - extra={ - "app_id": str(app_id), - "model_count": len(models) - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 准备所有模型配置 - model_configs = [] - for model_item in models: - model_config = self.db.get(ModelConfig, model_item.model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - # 合并参数:agent配置参数 + 请求覆盖参数 - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id - }) - - # 4. 调用 DraftRunService 的流式对比方法 - draft_service = DraftRunService(self.db) - async for event in draft_service.run_compare_stream( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout - ): - yield event - - logger.info( - "多模型对比流式完成", - extra={"app_id": str(app_id)} - ) - - # ==================== 向后兼容的函数接口 ==================== # 保留函数接口以兼容现有代码,但内部使用服务类 @@ -2278,53 +1912,6 @@ def get_apps_by_ids( return service.get_apps_by_ids(app_ids, workspace_id) -# ==================== 向后兼容的函数接口 ==================== - -async def draft_run( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None -) -> Dict[str, Any]: - """试运行 Agent(向后兼容接口)""" - service = AppService(db) - return await service.draft_run( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id - ) - - -async def draft_run_stream( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None -): - """试运行 Agent 流式返回(向后兼容接口)""" - service = AppService(db) - async for event in service.draft_run_stream( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id - ): - yield event - - # ==================== 依赖注入函数 ==================== def get_app_service( diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 693f1a26..0cf68be2 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -17,6 +17,7 @@ from sqlalchemy.orm import Session from app.celery_app import celery_app from app.core.agent.agent_middleware import AgentMiddleware +from app.core.agent.langchain_agent import LangChainAgent from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger @@ -26,6 +27,7 @@ from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service +from app.services.conversation_service import ConversationService from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger @@ -52,8 +54,12 @@ class LongTermMemoryInput(BaseModel): description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写") -def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None): +def create_long_term_memory_tool( + memory_config: Dict[str, Any], + end_user_id: str, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None +): """创建记忆工具, @@ -61,6 +67,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str memory_config: 记忆配置 end_user_id: 用户ID storage_type: 存储类型(可选) + user_rag_memory_id: 用户RAG记忆ID(可选) Returns: 长期记忆工具 @@ -188,7 +195,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: - query: 需要检索的问题或关键词 + kb_config: 知识库配置 + kb_ids: 知识库ID列表 + user_id: 用户ID Returns: 检索到的相关知识内容 @@ -232,17 +241,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): return knowledge_retrieval_tool -class DraftRunService: - """试运行服务类""" +class AgentRunService: + """Agent运行服务类""" def __init__(self, db: Session): - """初始化试运行服务 + """Agent运行服务 Args: db: 数据库会话 """ self.db = db + @staticmethod + def prepare_variables( + input_vars: dict | None, + variables_config: dict | None + ) -> dict: + input_vars = input_vars or {} + for variable in variables_config: + if variable.get("required") and variable.get("name") not in input_vars: + raise ValueError(f"The required parameter '{variable.get('name')}' was not provided") + return input_vars + + def load_tools_config(self, tools_config, web_search, tenant_id) -> list: + """加载工具配置""" + if not tools_config: + return [] + tools = [] + tool_service = ToolService(self.db) + + if tools_config and isinstance(tools_config, list): + for tool_config in tools_config: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + elif tools_config and isinstance(tools_config, dict): + web_search_choice = tools_config.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search and web_search_enable: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + return tools + + def load_skill_config( + self, + skills_config: dict | None, + message: str, tenant_id + ) -> tuple[list, str]: + if not skills_config: + return [], "" + + tools = [] + skill_prompts = "" + skill_enable = skills_config.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills_config) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, + tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + skill_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + + return tools, skill_prompts + + def load_knowledge_retrieval_config( + self, + knowledge_retrieval_config: dict | None, + user_id + ) -> list: + if not knowledge_retrieval_config: + return [] + + tools = [] + knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) + kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) + if kb_ids: + # 创建知识库检索工具 + kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id) + tools.append(kb_tool) + + logger.debug( + "已添加知识库检索工具", + extra={ + "kb_ids": kb_ids, + "tool_count": len(tools) + } + ) + return tools + + def load_memory_config( + self, + memory_config: dict | None, + user_id, + storage_type, + user_rag_memory_id + ) -> tuple[list, bool]: + """加载长期记忆配置""" + if not memory_config: + return [], False + + tools = [] + if memory_config.get("enabled"): + if user_id: + # 创建长期记忆工具 + memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, + user_rag_memory_id) + tools.append(memory_tool) + + logger.debug( + "已添加长期记忆工具", + extra={ + "user_id": user_id, + "tool_count": len(tools) + } + ) + return tools, bool(memory_config.get("enabled")) + async def run( self, *, @@ -270,19 +403,21 @@ class DraftRunService: conversation_id: 会话ID(用于多轮对话) user_id: 用户ID variables: 自定义变量参数值 + storage_type: 存储类型(可选) + user_rag_memory_id: 用户RAG记忆ID(可选) + web_search: 是否启用网络搜索(默认True) + memory: 是否启用长期记忆(默认True) + sub_agent: 是否为子代理调用(默认False) + files: 多模态文件列表(可选) Returns: Dict: 包含 AI 回复和元数据的字典 """ - memory_flag = False - - print('===========', storage_type) - - print(user_id) - if variables == None: variables = {} - from app.core.agent.langchain_agent import LangChainAgent - start_time = time.time() + tools_config: dict | list | None = agent_config.tools + skills_config: dict | None = agent_config.skills + knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval + memory_config: dict | None = agent_config.memory try: # 1. 获取 API Key 配置 @@ -302,112 +437,40 @@ class DraftRunService: agent_config=agent_config ) - items_params = variables + if sub_agent: + variables = self.prepare_variables(variables, agent_config.variables) + else: + # FIXME: subagent input valid + variables = variables or {} + system_prompt = render_prompt_message( - agent_config.system_prompt, # 修正拼写错误 + agent_config.system_prompt, PromptMessageRole.USER, - items_params + variables ) # 3. 处理系统提示词(支持变量替换) system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" - print('系统提示词:', system_prompt) # 4. 准备工具列表 tools = [] - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 - if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): - if hasattr(agent_config, 'tools') and agent_config.tools: - for tool_config in agent_config.tools: - print("+" * 50) - print(f"agent_config:{agent_config}") - print(f"tool_config:{tool_config}") - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): - web_tools = agent_config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(agent_config, 'skills') and agent_config.skills: - skills = agent_config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - if agent_config.knowledge_retrieval: - kb_config = agent_config.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) - if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id) - tools.append(kb_tool) - - logger.debug( - "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } - ) - + tools.extend(self.load_tools_config(tools_config, web_search, tenant_id)) + skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) # 添加长期记忆工具 + memory_flag = False if memory: - if agent_config.memory and agent_config.memory.get("enabled"): - memory_flag = True - - memory_config = agent_config.memory - if user_id: - # 创建长期记忆工具 - memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, - user_rag_memory_id) - tools.append(memory_tool) - - logger.debug( - "已添加长期记忆工具", - extra={ - "user_id": user_id, - "tool_count": len(tools) - } - ) + memory_tools, memory_flag = self.load_memory_config( + memory_config, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 4. 创建 LangChain Agent agent = LangChainAgent( @@ -432,7 +495,7 @@ class DraftRunService: # 6. 加载历史消息 history = [] - if agent_config.memory and agent_config.memory.get("enabled"): + if memory_config and memory_config.get("enabled"): history = await self._load_conversation_history( conversation_id=conversation_id, max_history=agent_config.memory.get("max_history", 10) @@ -482,7 +545,7 @@ class DraftRunService: ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) # 9. 保存会话消息 - if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): + if not sub_agent and memory_config and memory_config.get("enabled"): await self._save_conversation_message( conversation_id=conversation_id, user_message=message, @@ -557,16 +620,21 @@ class DraftRunService: Yields: str: SSE 格式的事件数据 """ - memory_flag = False - if variables == None: variables = {} - - from app.core.agent.langchain_agent import LangChainAgent + tools_config: dict | list | None = agent_config.tools + skills_config: dict | None = agent_config.skills + knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval + memory_config: dict | None = agent_config.memory start_time = time.time() try: # 1. 获取 API Key 配置 api_key_config = await self._get_api_key(model_config.id) + if not sub_agent: + variables = self.prepare_variables(variables, agent_config.variables) + else: + # FIXME: subagent input valid + variables = variables or {} # 2. 合并模型参数 effective_params = ModelParameterMerger.get_effective_parameters( @@ -588,95 +656,22 @@ class DraftRunService: # 4. 准备工具列表 tools = [] - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 - if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): - for tool_config in agent_config.tools: - # print("+"*50) - # print(f"agent_config:{agent_config}") - # print(f"tool_config:{tool_config}") - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): - web_tools = agent_config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) + tools.extend(self.load_tools_config(tools_config, web_search, tenant_id)) + skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(agent_config, 'skills') and agent_config.skills: - skills = agent_config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - if agent_config.knowledge_retrieval: - kb_config = agent_config.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) - if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id) - tools.append(kb_tool) - - logger.debug( - "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } - ) # 添加长期记忆工具 + memory_flag = False if memory: - if agent_config.memory and agent_config.memory.get("enabled"): - memory_flag = True - memory_config = agent_config.memory - if user_id: - # 创建长期记忆工具 - memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, - user_rag_memory_id) - tools.append(memory_tool) - - logger.debug( - "已添加长期记忆工具", - extra={ - "user_id": user_id, - "tool_count": len(tools) - } - ) + memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type, + user_rag_memory_id) + tools.extend(memory_tools) # 4. 创建 LangChain Agent agent = LangChainAgent( @@ -702,10 +697,10 @@ class DraftRunService: # 6. 加载历史消息 history = [] - if agent_config.memory and agent_config.memory.get("enabled"): + if memory_config and memory_config.get("enabled"): history = await self._load_conversation_history( conversation_id=conversation_id, - max_history=agent_config.memory.get("max_history", 10) + max_history=memory_config.get("max_history", 10) ) # 6. 处理多模态文件 @@ -763,7 +758,7 @@ class DraftRunService: }) # 10. 保存会话消息 - if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): + if not sub_agent and memory_config and memory_config.get("enabled"): await self._save_conversation_message( conversation_id=conversation_id, user_message=message, @@ -969,7 +964,6 @@ class DraftRunService: List[Dict]: 历史消息列表 """ try: - from app.services.conversation_service import ConversationService conversation_service = ConversationService(self.db) history = conversation_service.get_conversation_history( @@ -1489,6 +1483,15 @@ class DraftRunService: "conversation_id": returned_conversation_id, "content": chunk })) + + if event_type == "error" and event_data: + await event_queue.put(self._format_sse_event("model_error", { + "model_index": idx, + "model_config_id": model_config_id, + "label": model_label, + "conversation_id": returned_conversation_id, + "error": event_data.get("error", "未知错误") + })) except Exception as e: logger.warning(f"解析流式事件失败: {e}") finally: @@ -1673,41 +1676,3 @@ class DraftRunService: "total_time": sum(r.get("elapsed_time", 0) for r in results) } ) - - -async def draft_run( - db: Session, - *, - agent_config: AgentConfig, - model_config: ModelConfig, - message: str, - user_id: Optional[str] = None, - kb_ids: Optional[List[str]] = None, - similarity_threshold: float = 0.7, - top_k: int = 3 -) -> Dict[str, Any]: - """试运行 Agent(便捷函数) - - Args: - db: 数据库会话 - agent_config: Agent 配置 - model_config: 模型配置 - message: 用户消息 - user_id: 用户ID - kb_ids: 知识库ID列表 - similarity_threshold: 相似度阈值 - top_k: 检索返回的文档数量 - - Returns: - Dict: 包含 AI 回复和元数据的字典 - """ - service = DraftRunService(db) - return await service.run( - agent_config=agent_config, - model_config=model_config, - message=message, - user_id=user_id, - kb_ids=kb_ids, - similarity_threshold=similarity_threshold, - top_k=top_k - ) diff --git a/api/app/services/langchain_tool_server.py b/api/app/services/langchain_tool_server.py index f44e4cdc..2c151956 100644 --- a/api/app/services/langchain_tool_server.py +++ b/api/app/services/langchain_tool_server.py @@ -9,6 +9,8 @@ load_dotenv() # 读取web_search环境变量 web_search_value = os.getenv('web_search') + + def Search(query): url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions" api_key = web_search_value @@ -18,23 +20,24 @@ def Search(query): "role": "user", "content": query } - ], #搜索输入 - "edition":"standard", #搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。 - "search_source": "baidu_search_v2", #使用的搜索引擎版本 - "resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5 + ], # 搜索输入 + "edition": "standard", # 搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。 + "search_source": "baidu_search_v2", # 使用的搜索引擎版本 + "resource_type_filter": [{"type": "web", "top_k": 20}], + # 支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5 "search_filter": { "range": { "page_time": { - "gte": "now-1w/d", #时间查询参数,大于或等于 - "lt": "now/d", #时间查询参数,小于 - "gt": "", #时间查询参数,大于 - "lte": "" #时间查询参数,小于或等于 + "gte": "now-1w/d", # 时间查询参数,大于或等于 + "lt": "now/d", # 时间查询参数,小于 + "gt": "", # 时间查询参数,大于 + "lte": "" # 时间查询参数,小于或等于 } } }, - "block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表 - "search_recency_filter":"week", #根据网页发布时间进行筛选,可填值为:week,month,semiyear,year - "enable_full_content":True #是否输出网页完整原文 + "block_websites": ["tieba.baidu.com"], # 需要屏蔽的站点列表 + "search_recency_filter": "week", # 根据网页发布时间进行筛选,可填值为:week,month,semiyear,year + "enable_full_content": True # 是否输出网页完整原文 }, ensure_ascii=False) headers = { 'Content-Type': 'application/json', @@ -42,10 +45,10 @@ def Search(query): } response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json() - content=[] + content = [] for i in response['references']: - title=i['title'] - snippet=i['snippet'] - content.append(title+';'+snippet) - content='。'.join(content) - return content \ No newline at end of file + title = i['title'] + snippet = i['snippet'] + content.append(title + ';' + snippet) + content = '。'.join(content) + return content diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index 650f639b..f42ee95a 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -123,11 +123,14 @@ class MultiAgentOrchestrator: user_id: 用户 ID variables: 变量参数 use_llm_routing: 是否使用 LLM 路由 + web_search: 是否启用网络搜索 + memory: 是否启用记忆功能 + storage_type: 存储类型 + user_rag_memory_id: 用户 RAG 记忆 ID Yields: SSE 格式的事件流 """ - import json start_time = time.time() @@ -200,7 +203,8 @@ class MultiAgentOrchestrator: except Exception as e: logger.error( "多 Agent 任务执行失败(流式)", - extra={"error": str(e), "mode": self._normalized_mode} + extra={"error": str(e), "mode": self._normalized_mode}, + exc_info=True ) # 发送错误事件 yield self._format_sse_event("error", { @@ -1267,7 +1271,7 @@ class MultiAgentOrchestrator: Yields: SSE 格式的事件流 """ - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService # 获取模型配置 model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) @@ -1278,7 +1282,7 @@ class MultiAgentOrchestrator: ) # 流式执行 Agent - draft_service = DraftRunService(self.db) + draft_service = AgentRunService(self.db) async for event in draft_service.run_stream( agent_config=agent_config, model_config=model_config, @@ -1320,7 +1324,7 @@ class MultiAgentOrchestrator: Returns: 执行结果 """ - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService # 获取模型配置 model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) @@ -1331,7 +1335,7 @@ class MultiAgentOrchestrator: ) # 执行 Agent - draft_service = DraftRunService(self.db) + draft_service = AgentRunService(self.db) result = await draft_service.run( agent_config=agent_config, model_config=model_config, @@ -1633,6 +1637,7 @@ class MultiAgentOrchestrator: self.memory = config_data.get("memory") self.variables = config_data.get("variables", []) self.tools = config_data.get("tools", {}) + self.skills = config_data.get("skills", {}) self.default_model_config_id = release.default_model_config_id return AgentConfigProxy(release, app, config_data) diff --git a/api/app/services/skill_service.py b/api/app/services/skill_service.py index 5eb80795..0b7de6cf 100644 --- a/api/app/services/skill_service.py +++ b/api/app/services/skill_service.py @@ -121,7 +121,7 @@ class SkillService: if skill and skill.is_active: # 加载技能关联的工具 for tool_config in skill.tools: - tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + tool = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool: langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 2bb96e53..d2400ded 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -209,7 +209,7 @@ class ToolService: try: # 获取工具实例 - tool = self._get_tool_instance(tool_id, tenant_id) + tool = self.get_tool_instance(tool_id, tenant_id) if not tool: return ToolResult.error_result( error=f"工具不存在: {tool_id}", @@ -335,7 +335,7 @@ class ToolService: return [] # 获取工具实例 - tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + tool_instance = self.get_tool_instance(str(config.id), config.tenant_id) if not tool_instance: return [] @@ -792,7 +792,7 @@ class ToolService: """获取工具配置""" return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) - def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: + def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: """获取工具实例""" if tool_id in self._tool_cache: return self._tool_cache[tool_id] @@ -1416,7 +1416,7 @@ class ToolService: """测试内置工具连接""" try: # 获取工具实例 - tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + tool_instance = self.get_tool_instance(str(config.id), config.tenant_id) if not tool_instance: return {"success": False, "message": "无法创建工具实例"}