fix(skills): configuration modification

This commit is contained in:
Timebomb2018
2026-02-04 18:06:29 +08:00
parent 161da723b9
commit 71abd16ae7
11 changed files with 184 additions and 91 deletions

View File

@@ -10,14 +10,17 @@ from app.repositories.skill_repository import SkillRepository
class AgentMiddleware:
"""Agent 中间件 - 用于动态过滤和加载技能"""
def __init__(self, skill_ids: Optional[List[str]] = None):
def __init__(self, skills: Optional[dict] = None):
"""
初始化中间件
Args:
skill_ids: 技能ID列表
skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]}
"""
self.skill_ids = skill_ids or []
self.skills = skills or {}
self.enabled = self.skills.get('enabled', False)
self.all_skills = self.skills.get('all_skills', False)
self.skill_ids = self.skills.get('skill_ids', [])
@staticmethod
def filter_tools(
@@ -95,9 +98,17 @@ class AgentMiddleware:
# base_tools 不属于任何 skill不加入映射
skill_configs = {}
skill_ids_to_load = []
if self.skill_ids:
for skill_id in self.skill_ids:
# 如果启用技能且 all_skills 为 True加载租户下所有激活的技能
if self.enabled and self.all_skills:
skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
skill_ids_to_load = [str(skill.id) for skill in skills]
elif self.enabled and self.skill_ids:
skill_ids_to_load = self.skill_ids
if skill_ids_to_load:
for skill_id in skill_ids_to_load:
try:
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
if skill and skill.is_active:
@@ -110,7 +121,7 @@ class AgentMiddleware:
continue
# 加载技能工具并获取映射关系
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, self.skill_ids, tenant_id)
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)
# 只添加不冲突的 skill_tools
for tool in skill_tools:

View File

@@ -29,8 +29,8 @@ class AgentConfig(Base):
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
memory = Column(JSON, nullable=True, comment="记忆配置")
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
skill_ids = Column(JSON, default=list, nullable=True, comment="关联的技能ID列表")
tools = Column(JSON, default=list, nullable=True, comment="工具配置")
skills = Column(JSON, default=dict, nullable=True, comment="技能配置")
# 多 Agent 相关字段
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")

View File

@@ -82,6 +82,12 @@ class ToolConfig(BaseModel):
tool_id: Optional[str] = Field(default=None, description="工具ID")
operation: Optional[str] = Field(default=None, description="工具特定配置")
class SkillConfig(BaseModel):
"""技能配置"""
enabled: bool = Field(default=True, description="是否启用该技能")
skill_ids: Optional[list[str]] = Field(default=list, description="技能ID列表")
all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能")
class ToolOldConfig(BaseModel):
"""工具配置"""
@@ -157,7 +163,7 @@ class AgentConfigCreate(BaseModel):
)
# 技能配置
skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表")
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
class AppCreate(BaseModel):
@@ -212,7 +218,7 @@ class AgentConfigUpdate(BaseModel):
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
# 技能配置
skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表")
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
# ---------- Output Schemas ----------
@@ -272,7 +278,7 @@ class AgentConfig(BaseModel):
# 工具配置
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
skill_ids: Optional[List[str]] = []
skills: Optional[SkillConfig] = {}
is_active: bool
created_at: datetime.datetime

View File

@@ -1,5 +1,5 @@
"""Skill Schema 定义"""
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, field_serializer
import uuid
from datetime import datetime
@@ -32,10 +32,17 @@ class SkillUpdate(BaseModel):
is_public: Optional[bool] = None
class Skill(SkillBase):
class Skill(BaseModel):
"""Skill 响应 Schema"""
id: uuid.UUID
tenant_id: uuid.UUID
name: str
description: Optional[str] = None
tools: Union[List[Dict[str, Any]], List[Dict[str, str]]] = Field(default_factory=list, description="工具列表,可以是简单格式或包含工具详情")
config: Dict[str, Any] = Field(default_factory=dict)
prompt: Optional[str] = None
is_active: bool = True
is_public: bool = False
created_at: datetime
updated_at: datetime

View File

@@ -9,7 +9,7 @@ from app.schemas.app_schema import (
VariableDefinition,
ToolConfig,
AgentConfigCreate,
AgentConfigUpdate, ToolOldConfig,
AgentConfigUpdate, ToolOldConfig, SkillConfig,
)
@@ -49,8 +49,8 @@ class AgentConfigConverter:
if hasattr(config, 'tools') and config.tools:
result["tools"] = [tool.model_dump() for tool in config.tools]
if hasattr(config, "skill_ids") and config.skill_ids:
result["skill_ids"] = [skill for skill in config.skill_ids]
if hasattr(config, "skills") and config.skills:
result["skills"] = config.skills.model_dump()
return result
@@ -61,7 +61,7 @@ class AgentConfigConverter:
memory: Optional[Dict[str, Any]],
variables: Optional[list],
tools: Optional[Union[list, Dict[str, Any]]],
skill_ids: Optional[list]
skills: Optional[dict]
) -> Dict[str, Any]:
"""
将数据库存储格式转换为 Pydantic 对象
@@ -72,7 +72,7 @@ class AgentConfigConverter:
memory: 记忆配置
variables: 变量配置
tools: 工具配置
skill_ids: 技能 ID 列表
skills: 技能列表
Returns:
包含 Pydantic 对象的字典
@@ -83,7 +83,7 @@ class AgentConfigConverter:
"memory": MemoryConfig(enabled=True),
"variables": [],
"tools": [],
"skill_ids": []
"skills": {}
}
# 1. 解析模型参数配置
@@ -124,7 +124,7 @@ class AgentConfigConverter:
for name, tool_data in tools.items()
}
if skill_ids:
result["skill_ids"] = [skill for skill in skill_ids]
if skills:
result["skills"] = SkillConfig(**skills)
return result

View File

@@ -26,7 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
memory=agent_cfg.memory,
variables=agent_cfg.variables,
tools=agent_cfg.tools,
skill_ids=agent_cfg.skill_ids
skills=agent_cfg.skills
)
# 将解析后的字段添加到对象上(用于序列化)
@@ -35,6 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
agent_cfg.memory = parsed["memory"]
agent_cfg.variables = parsed["variables"]
agent_cfg.tools = parsed["tools"]
agent_cfg.skill_ids = parsed["skill_ids"]
agent_cfg.skills = parsed["skills"]
return agent_cfg

View File

@@ -8,6 +8,7 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -79,21 +80,55 @@ class AppChatService:
# 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
@@ -113,22 +148,6 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.model_parameters
@@ -246,20 +265,54 @@ class AppChatService:
# 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
@@ -279,22 +332,6 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.model_parameters

View File

@@ -304,7 +304,7 @@ class AppService:
memory=storage_data.get("memory"),
variables=storage_data.get("variables", []),
tools=storage_data.get("tools", []),
skill_ids=storage_data.get("skill_ids", []),
skills=storage_data.get("skills", {}),
is_active=True,
created_at=now,
updated_at=now,
@@ -908,7 +908,7 @@ class AppService:
agent_cfg.variables = storage_data.get("variables", [])
# if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.skill_ids = storage_data.get("skill_ids", [])
agent_cfg.skills = storage_data.get("skills", {})
agent_cfg.updated_at = now
@@ -1001,6 +1001,7 @@ class AppService:
},
variables=[],
tools=[],
skills=[],
is_active=True,
created_at=now,
updated_at=now,
@@ -1219,6 +1220,7 @@ class AppService:
"memory": agent_cfg.memory,
"variables": agent_cfg.variables or [],
"tools": agent_cfg.tools or [],
"skill_ids": agent_cfg.skill_ids or [],
}
# config = AgentConfigConverter.from_storage_format(agent_cfg)
default_model_config_id = agent_cfg.default_model_config_id

View File

@@ -348,20 +348,23 @@ class DraftRunService:
)
# 加载技能关联的工具
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -610,21 +613,23 @@ class DraftRunService:
)
# 加载技能关联的工具
skill_configs = {}
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具

View File

@@ -19,6 +19,14 @@ class SkillService:
@staticmethod
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
"""创建技能"""
# 检查同名技能
existing = db.query(Skill).filter(
Skill.tenant_id == tenant_id,
Skill.name == data.name
).first()
if existing:
raise BusinessException(f"技能名称'{data.name}'已存在", BizCode.DUPLICATE_NAME)
skill = SkillRepository.create(db, data, tenant_id)
db.commit()
db.refresh(skill)
@@ -31,6 +39,22 @@ class SkillService:
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
if not skill:
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
# 填充工具详情
tool_service = ToolService(db)
enriched_tools = []
for tool_config in skill.tools:
tool_id = tool_config.get("tool_id")
if tool_id:
tool_info = tool_service.get_tool_info(tool_id, tenant_id)
if tool_info:
enriched_tools.append({
"tool_id": tool_id,
"operation": tool_config.get("operation"),
"tool_info": tool_info
})
skill.tools = enriched_tools
return skill
except (BusinessException, SQLAlchemyError) as e:
db.rollback()

View File

@@ -99,7 +99,8 @@ def agent_config_4_app_release(release: AppRelease) -> AgentConfig:
knowledge_retrieval=config_dict.get("knowledge_retrieval"),
memory=config_dict.get("memory"),
variables=config_dict.get("variables", []),
tools=config_dict.get("tools", {}),
tools=config_dict.get("tools", []),
skill_ids=config_dict.get("skill_ids", [])
)
return agent_config