From 71abd16ae76e1bd66b85d73f20a996b5feaeee29 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Wed, 4 Feb 2026 18:06:29 +0800 Subject: [PATCH] fix(skills): configuration modification --- api/app/core/agent/agent_middleware.py | 23 +++-- api/app/models/agent_app_config_model.py | 4 +- api/app/schemas/app_schema.py | 12 ++- api/app/schemas/skill_schema.py | 11 +- api/app/services/agent_config_converter.py | 16 +-- api/app/services/agent_config_helper.py | 4 +- api/app/services/app_chat_service.py | 113 ++++++++++++++------- api/app/services/app_service.py | 6 +- api/app/services/draft_run_service.py | 59 ++++++----- api/app/services/skill_service.py | 24 +++++ api/app/utils/app_config_utils.py | 3 +- 11 files changed, 184 insertions(+), 91 deletions(-) diff --git a/api/app/core/agent/agent_middleware.py b/api/app/core/agent/agent_middleware.py index ef5f7847..735423c9 100644 --- a/api/app/core/agent/agent_middleware.py +++ b/api/app/core/agent/agent_middleware.py @@ -10,14 +10,17 @@ from app.repositories.skill_repository import SkillRepository class AgentMiddleware: """Agent 中间件 - 用于动态过滤和加载技能""" - def __init__(self, skill_ids: Optional[List[str]] = None): + def __init__(self, skills: Optional[dict] = None): """ 初始化中间件 Args: - skill_ids: 技能ID列表 + skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]} """ - self.skill_ids = skill_ids or [] + self.skills = skills or {} + self.enabled = self.skills.get('enabled', False) + self.all_skills = self.skills.get('all_skills', False) + self.skill_ids = self.skills.get('skill_ids', []) @staticmethod def filter_tools( @@ -95,9 +98,17 @@ class AgentMiddleware: # base_tools 不属于任何 skill,不加入映射 skill_configs = {} + skill_ids_to_load = [] - if self.skill_ids: - for skill_id in self.skill_ids: + # 如果启用技能且 all_skills 为 True,加载租户下所有激活的技能 + if self.enabled and self.all_skills: + skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000) + skill_ids_to_load = [str(skill.id) for skill in skills] + elif self.enabled and self.skill_ids: + skill_ids_to_load = self.skill_ids + + if skill_ids_to_load: + for skill_id in skill_ids_to_load: try: skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id) if skill and skill.is_active: @@ -110,7 +121,7 @@ class AgentMiddleware: continue # 加载技能工具并获取映射关系 - skill_tools, skill_tool_map = SkillService.load_skill_tools(db, self.skill_ids, tenant_id) + skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id) # 只添加不冲突的 skill_tools for tool in skill_tools: diff --git a/api/app/models/agent_app_config_model.py b/api/app/models/agent_app_config_model.py index 7ed70728..cc2e0686 100644 --- a/api/app/models/agent_app_config_model.py +++ b/api/app/models/agent_app_config_model.py @@ -29,8 +29,8 @@ class AgentConfig(Base): knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置") memory = Column(JSON, nullable=True, comment="记忆配置") variables = Column(JSON, default=list, nullable=True, comment="变量配置") - tools = Column(JSON, default=dict, nullable=True, comment="工具配置") - skill_ids = Column(JSON, default=list, nullable=True, comment="关联的技能ID列表") + tools = Column(JSON, default=list, nullable=True, comment="工具配置") + skills = Column(JSON, default=dict, nullable=True, comment="技能配置") # 多 Agent 相关字段 agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone") diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index bcfeca57..2ad27ace 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -82,6 +82,12 @@ class ToolConfig(BaseModel): tool_id: Optional[str] = Field(default=None, description="工具ID") operation: Optional[str] = Field(default=None, description="工具特定配置") +class SkillConfig(BaseModel): + """技能配置""" + enabled: bool = Field(default=True, description="是否启用该技能") + skill_ids: Optional[list[str]] = Field(default=list, description="技能ID列表") + all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能") + class ToolOldConfig(BaseModel): """工具配置""" @@ -157,7 +163,7 @@ class AgentConfigCreate(BaseModel): ) # 技能配置 - skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表") + skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表") class AppCreate(BaseModel): @@ -212,7 +218,7 @@ class AgentConfigUpdate(BaseModel): tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表") # 技能配置 - skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表") + skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表") # ---------- Output Schemas ---------- @@ -272,7 +278,7 @@ class AgentConfig(BaseModel): # 工具配置 tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = [] - skill_ids: Optional[List[str]] = [] + skills: Optional[SkillConfig] = {} is_active: bool created_at: datetime.datetime diff --git a/api/app/schemas/skill_schema.py b/api/app/schemas/skill_schema.py index 27f16b99..f002308e 100644 --- a/api/app/schemas/skill_schema.py +++ b/api/app/schemas/skill_schema.py @@ -1,5 +1,5 @@ """Skill Schema 定义""" -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict, Any, Union from pydantic import BaseModel, Field, field_serializer import uuid from datetime import datetime @@ -32,10 +32,17 @@ class SkillUpdate(BaseModel): is_public: Optional[bool] = None -class Skill(SkillBase): +class Skill(BaseModel): """Skill 响应 Schema""" id: uuid.UUID tenant_id: uuid.UUID + name: str + description: Optional[str] = None + tools: Union[List[Dict[str, Any]], List[Dict[str, str]]] = Field(default_factory=list, description="工具列表,可以是简单格式或包含工具详情") + config: Dict[str, Any] = Field(default_factory=dict) + prompt: Optional[str] = None + is_active: bool = True + is_public: bool = False created_at: datetime updated_at: datetime diff --git a/api/app/services/agent_config_converter.py b/api/app/services/agent_config_converter.py index ba76e299..4b5cc3b6 100644 --- a/api/app/services/agent_config_converter.py +++ b/api/app/services/agent_config_converter.py @@ -9,7 +9,7 @@ from app.schemas.app_schema import ( VariableDefinition, ToolConfig, AgentConfigCreate, - AgentConfigUpdate, ToolOldConfig, + AgentConfigUpdate, ToolOldConfig, SkillConfig, ) @@ -49,8 +49,8 @@ class AgentConfigConverter: if hasattr(config, 'tools') and config.tools: result["tools"] = [tool.model_dump() for tool in config.tools] - if hasattr(config, "skill_ids") and config.skill_ids: - result["skill_ids"] = [skill for skill in config.skill_ids] + if hasattr(config, "skills") and config.skills: + result["skills"] = config.skills.model_dump() return result @@ -61,7 +61,7 @@ class AgentConfigConverter: memory: Optional[Dict[str, Any]], variables: Optional[list], tools: Optional[Union[list, Dict[str, Any]]], - skill_ids: Optional[list] + skills: Optional[dict] ) -> Dict[str, Any]: """ 将数据库存储格式转换为 Pydantic 对象 @@ -72,7 +72,7 @@ class AgentConfigConverter: memory: 记忆配置 variables: 变量配置 tools: 工具配置 - skill_ids: 技能 ID 列表 + skills: 技能列表 Returns: 包含 Pydantic 对象的字典 @@ -83,7 +83,7 @@ class AgentConfigConverter: "memory": MemoryConfig(enabled=True), "variables": [], "tools": [], - "skill_ids": [] + "skills": {} } # 1. 解析模型参数配置 @@ -124,7 +124,7 @@ class AgentConfigConverter: for name, tool_data in tools.items() } - if skill_ids: - result["skill_ids"] = [skill for skill in skill_ids] + if skills: + result["skills"] = SkillConfig(**skills) return result diff --git a/api/app/services/agent_config_helper.py b/api/app/services/agent_config_helper.py index ef6e22a4..08d28424 100644 --- a/api/app/services/agent_config_helper.py +++ b/api/app/services/agent_config_helper.py @@ -26,7 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig: memory=agent_cfg.memory, variables=agent_cfg.variables, tools=agent_cfg.tools, - skill_ids=agent_cfg.skill_ids + skills=agent_cfg.skills ) # 将解析后的字段添加到对象上(用于序列化) @@ -35,6 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig: agent_cfg.memory = parsed["memory"] agent_cfg.variables = parsed["variables"] agent_cfg.tools = parsed["tools"] - agent_cfg.skill_ids = parsed["skill_ids"] + agent_cfg.skills = parsed["skills"] return agent_cfg diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index fcd4bc79..3556bb88 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -8,6 +8,7 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List from fastapi import Depends from sqlalchemy.orm import Session +from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.langchain_agent import LangChainAgent from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -79,21 +80,55 @@ class AppChatService: # 获取工具服务 tool_service = ToolService(self.db) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): for tool_config in config.tools: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, workspace_id)) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: if tool_instance.name == "baidu_search_tool" and not web_search: continue # 转换为LangChain工具 langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) + elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): + web_tools = config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search: + if web_search_enable: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + + # 加载技能关联的工具 + if hasattr(config, 'skills') and config.skills: + skills = config.skills + skill_enable = skills.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, + tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" # 添加知识库检索工具 knowledge_retrieval = config.knowledge_retrieval @@ -113,22 +148,6 @@ class AppChatService: memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - # 获取模型参数 model_parameters = config.model_parameters @@ -246,20 +265,54 @@ class AppChatService: # 获取工具服务 tool_service = ToolService(self.db) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): for tool_config in config.tools: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, workspace_id)) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: if tool_instance.name == "baidu_search_tool" and not web_search: continue # 转换为LangChain工具 langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) + elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): + web_tools = config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search: + if web_search_enable: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + + # 加载技能关联的工具 + if hasattr(config, 'skills') and config.skills: + skills = config.skills + skill_enable = skills.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, + tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" # 添加知识库检索工具 knowledge_retrieval = config.knowledge_retrieval @@ -279,22 +332,6 @@ class AppChatService: memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - # 获取模型参数 model_parameters = config.model_parameters diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 1759206f..75044a5d 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -304,7 +304,7 @@ class AppService: memory=storage_data.get("memory"), variables=storage_data.get("variables", []), tools=storage_data.get("tools", []), - skill_ids=storage_data.get("skill_ids", []), + skills=storage_data.get("skills", {}), is_active=True, created_at=now, updated_at=now, @@ -908,7 +908,7 @@ class AppService: agent_cfg.variables = storage_data.get("variables", []) # if data.tools is not None: agent_cfg.tools = storage_data.get("tools", []) - agent_cfg.skill_ids = storage_data.get("skill_ids", []) + agent_cfg.skills = storage_data.get("skills", {}) agent_cfg.updated_at = now @@ -1001,6 +1001,7 @@ class AppService: }, variables=[], tools=[], + skills=[], is_active=True, created_at=now, updated_at=now, @@ -1219,6 +1220,7 @@ class AppService: "memory": agent_cfg.memory, "variables": agent_cfg.variables or [], "tools": agent_cfg.tools or [], + "skill_ids": agent_cfg.skill_ids or [], } # config = AgentConfigConverter.from_storage_format(agent_cfg) default_model_config_id = agent_cfg.default_model_config_id diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 0e0922bc..34e9f865 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -348,20 +348,23 @@ class DraftRunService: ) # 加载技能关联的工具 - if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids: - middleware = AgentMiddleware(skill_ids=agent_config.skill_ids) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + if hasattr(agent_config, 'skills') and agent_config.skills: + skills = agent_config.skills + skill_enable = skills.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" # 添加知识库检索工具 if agent_config.knowledge_retrieval: @@ -610,21 +613,23 @@ class DraftRunService: ) # 加载技能关联的工具 - skill_configs = {} - if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids: - middleware = AgentMiddleware(skill_ids=agent_config.skill_ids) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + if hasattr(agent_config, 'skills') and agent_config.skills: + skills = agent_config.skills + skill_enable = skills.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" # 添加知识库检索工具 diff --git a/api/app/services/skill_service.py b/api/app/services/skill_service.py index a20e1b22..ea21b2ad 100644 --- a/api/app/services/skill_service.py +++ b/api/app/services/skill_service.py @@ -19,6 +19,14 @@ class SkillService: @staticmethod def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill: """创建技能""" + # 检查同名技能 + existing = db.query(Skill).filter( + Skill.tenant_id == tenant_id, + Skill.name == data.name + ).first() + if existing: + raise BusinessException(f"技能名称'{data.name}'已存在", BizCode.DUPLICATE_NAME) + skill = SkillRepository.create(db, data, tenant_id) db.commit() db.refresh(skill) @@ -31,6 +39,22 @@ class SkillService: skill = SkillRepository.get_by_id(db, skill_id, tenant_id) if not skill: raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND) + + # 填充工具详情 + tool_service = ToolService(db) + enriched_tools = [] + for tool_config in skill.tools: + tool_id = tool_config.get("tool_id") + if tool_id: + tool_info = tool_service.get_tool_info(tool_id, tenant_id) + if tool_info: + enriched_tools.append({ + "tool_id": tool_id, + "operation": tool_config.get("operation"), + "tool_info": tool_info + }) + skill.tools = enriched_tools + return skill except (BusinessException, SQLAlchemyError) as e: db.rollback() diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index 06549989..328e88e5 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -99,7 +99,8 @@ def agent_config_4_app_release(release: AppRelease) -> AgentConfig: knowledge_retrieval=config_dict.get("knowledge_retrieval"), memory=config_dict.get("memory"), variables=config_dict.get("variables", []), - tools=config_dict.get("tools", {}), + tools=config_dict.get("tools", []), + skill_ids=config_dict.get("skill_ids", []) ) return agent_config