fix(skills): configuration modification
This commit is contained in:
@@ -9,7 +9,7 @@ from app.schemas.app_schema import (
|
||||
VariableDefinition,
|
||||
ToolConfig,
|
||||
AgentConfigCreate,
|
||||
AgentConfigUpdate, ToolOldConfig,
|
||||
AgentConfigUpdate, ToolOldConfig, SkillConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,8 +49,8 @@ class AgentConfigConverter:
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
result["tools"] = [tool.model_dump() for tool in config.tools]
|
||||
|
||||
if hasattr(config, "skill_ids") and config.skill_ids:
|
||||
result["skill_ids"] = [skill for skill in config.skill_ids]
|
||||
if hasattr(config, "skills") and config.skills:
|
||||
result["skills"] = config.skills.model_dump()
|
||||
|
||||
return result
|
||||
|
||||
@@ -61,7 +61,7 @@ class AgentConfigConverter:
|
||||
memory: Optional[Dict[str, Any]],
|
||||
variables: Optional[list],
|
||||
tools: Optional[Union[list, Dict[str, Any]]],
|
||||
skill_ids: Optional[list]
|
||||
skills: Optional[dict]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将数据库存储格式转换为 Pydantic 对象
|
||||
@@ -72,7 +72,7 @@ class AgentConfigConverter:
|
||||
memory: 记忆配置
|
||||
variables: 变量配置
|
||||
tools: 工具配置
|
||||
skill_ids: 技能 ID 列表
|
||||
skills: 技能列表
|
||||
|
||||
Returns:
|
||||
包含 Pydantic 对象的字典
|
||||
@@ -83,7 +83,7 @@ class AgentConfigConverter:
|
||||
"memory": MemoryConfig(enabled=True),
|
||||
"variables": [],
|
||||
"tools": [],
|
||||
"skill_ids": []
|
||||
"skills": {}
|
||||
}
|
||||
|
||||
# 1. 解析模型参数配置
|
||||
@@ -124,7 +124,7 @@ class AgentConfigConverter:
|
||||
for name, tool_data in tools.items()
|
||||
}
|
||||
|
||||
if skill_ids:
|
||||
result["skill_ids"] = [skill for skill in skill_ids]
|
||||
if skills:
|
||||
result["skills"] = SkillConfig(**skills)
|
||||
|
||||
return result
|
||||
|
||||
@@ -26,7 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
||||
memory=agent_cfg.memory,
|
||||
variables=agent_cfg.variables,
|
||||
tools=agent_cfg.tools,
|
||||
skill_ids=agent_cfg.skill_ids
|
||||
skills=agent_cfg.skills
|
||||
)
|
||||
|
||||
# 将解析后的字段添加到对象上(用于序列化)
|
||||
@@ -35,6 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
||||
agent_cfg.memory = parsed["memory"]
|
||||
agent_cfg.variables = parsed["variables"]
|
||||
agent_cfg.tools = parsed["tools"]
|
||||
agent_cfg.skill_ids = parsed["skill_ids"]
|
||||
agent_cfg.skills = parsed["skills"]
|
||||
|
||||
return agent_cfg
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -79,21 +80,55 @@ class AppChatService:
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, workspace_id))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(config, 'skills') and config.skills:
|
||||
skills = config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.knowledge_retrieval
|
||||
@@ -113,22 +148,6 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
|
||||
@@ -246,20 +265,54 @@ class AppChatService:
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, workspace_id))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(config, 'skills') and config.skills:
|
||||
skills = config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.knowledge_retrieval
|
||||
@@ -279,22 +332,6 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
|
||||
|
||||
@@ -304,7 +304,7 @@ class AppService:
|
||||
memory=storage_data.get("memory"),
|
||||
variables=storage_data.get("variables", []),
|
||||
tools=storage_data.get("tools", []),
|
||||
skill_ids=storage_data.get("skill_ids", []),
|
||||
skills=storage_data.get("skills", {}),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -908,7 +908,7 @@ class AppService:
|
||||
agent_cfg.variables = storage_data.get("variables", [])
|
||||
# if data.tools is not None:
|
||||
agent_cfg.tools = storage_data.get("tools", [])
|
||||
agent_cfg.skill_ids = storage_data.get("skill_ids", [])
|
||||
agent_cfg.skills = storage_data.get("skills", {})
|
||||
|
||||
agent_cfg.updated_at = now
|
||||
|
||||
@@ -1001,6 +1001,7 @@ class AppService:
|
||||
},
|
||||
variables=[],
|
||||
tools=[],
|
||||
skills=[],
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -1219,6 +1220,7 @@ class AppService:
|
||||
"memory": agent_cfg.memory,
|
||||
"variables": agent_cfg.variables or [],
|
||||
"tools": agent_cfg.tools or [],
|
||||
"skill_ids": agent_cfg.skill_ids or [],
|
||||
}
|
||||
# config = AgentConfigConverter.from_storage_format(agent_cfg)
|
||||
default_model_config_id = agent_cfg.default_model_config_id
|
||||
|
||||
@@ -348,20 +348,23 @@ class DraftRunService:
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
|
||||
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||
skills = agent_config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
@@ -610,21 +613,23 @@ class DraftRunService:
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
skill_configs = {}
|
||||
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
|
||||
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||
skills = agent_config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
|
||||
# 添加知识库检索工具
|
||||
|
||||
@@ -19,6 +19,14 @@ class SkillService:
|
||||
@staticmethod
|
||||
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
|
||||
"""创建技能"""
|
||||
# 检查同名技能
|
||||
existing = db.query(Skill).filter(
|
||||
Skill.tenant_id == tenant_id,
|
||||
Skill.name == data.name
|
||||
).first()
|
||||
if existing:
|
||||
raise BusinessException(f"技能名称'{data.name}'已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
skill = SkillRepository.create(db, data, tenant_id)
|
||||
db.commit()
|
||||
db.refresh(skill)
|
||||
@@ -31,6 +39,22 @@ class SkillService:
|
||||
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
|
||||
if not skill:
|
||||
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
|
||||
|
||||
# 填充工具详情
|
||||
tool_service = ToolService(db)
|
||||
enriched_tools = []
|
||||
for tool_config in skill.tools:
|
||||
tool_id = tool_config.get("tool_id")
|
||||
if tool_id:
|
||||
tool_info = tool_service.get_tool_info(tool_id, tenant_id)
|
||||
if tool_info:
|
||||
enriched_tools.append({
|
||||
"tool_id": tool_id,
|
||||
"operation": tool_config.get("operation"),
|
||||
"tool_info": tool_info
|
||||
})
|
||||
skill.tools = enriched_tools
|
||||
|
||||
return skill
|
||||
except (BusinessException, SQLAlchemyError) as e:
|
||||
db.rollback()
|
||||
|
||||
Reference in New Issue
Block a user