feat(skills and model):

1. Add the "Skills" module;
 2. The loading of the model square has been modified to be controlled through environment variables;
 3. Dynamic scheduling of the skill binding tool;
 4. Agent Integration Skills
This commit is contained in:
Timebomb2018
2026-02-04 12:21:38 +08:00
parent 3b5df793fb
commit b8f1095f53
21 changed files with 642 additions and 54 deletions

View File

@@ -48,6 +48,9 @@ class AgentConfigConverter:
# 5. 工具配置
if hasattr(config, 'tools') and config.tools:
result["tools"] = [tool.model_dump() for tool in config.tools]
if hasattr(config, "skill_ids") and config.skill_ids:
result["skill_ids"] = [skill for skill in config.skill_ids]
return result
@@ -58,6 +61,7 @@ class AgentConfigConverter:
memory: Optional[Dict[str, Any]],
variables: Optional[list],
tools: Optional[Union[list, Dict[str, Any]]],
skill_ids: Optional[list]
) -> Dict[str, Any]:
"""
将数据库存储格式转换为 Pydantic 对象
@@ -68,6 +72,7 @@ class AgentConfigConverter:
memory: 记忆配置
variables: 变量配置
tools: 工具配置
skill_ids: 技能 ID 列表
Returns:
包含 Pydantic 对象的字典
@@ -78,6 +83,7 @@ class AgentConfigConverter:
"memory": MemoryConfig(enabled=True),
"variables": [],
"tools": [],
"skill_ids": []
}
# 1. 解析模型参数配置
@@ -117,5 +123,8 @@ class AgentConfigConverter:
name: ToolOldConfig(**tool_data)
for name, tool_data in tools.items()
}
if skill_ids:
result["skill_ids"] = [skill for skill in skill_ids]
return result

View File

@@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
memory=agent_cfg.memory,
variables=agent_cfg.variables,
tools=agent_cfg.tools,
skill_ids=agent_cfg.skill_ids
)
# 将解析后的字段添加到对象上(用于序列化)
@@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
agent_cfg.memory = parsed["memory"]
agent_cfg.variables = parsed["variables"]
agent_cfg.tools = parsed["tools"]
agent_cfg.skill_ids = parsed["skill_ids"]
return agent_cfg

View File

@@ -304,6 +304,7 @@ class AppService:
memory=storage_data.get("memory"),
variables=storage_data.get("variables", []),
tools=storage_data.get("tools", []),
skill_ids=storage_data.get("skill_ids", []),
is_active=True,
created_at=now,
updated_at=now,
@@ -907,6 +908,7 @@ class AppService:
agent_cfg.variables = storage_data.get("variables", [])
# if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.skill_ids = storage_data.get("skill_ids", [])
agent_cfg.updated_at = now

View File

@@ -187,7 +187,7 @@ class AppStatisticsService:
daily_tokens[date_str] = 0
daily_tokens[date_str] += int(tokens)
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
total = sum(row["tokens"] for row in daily_data)
return {"daily": daily_data, "total": total}

View File

@@ -10,6 +10,11 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -26,10 +31,8 @@ from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.tool_service import ToolService
from app.services.multimodal_service import MultimodalService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -310,6 +313,7 @@ class DraftRunService:
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
@@ -320,9 +324,7 @@ class DraftRunService:
print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
@@ -345,6 +347,22 @@ class DraftRunService:
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
@@ -558,6 +576,7 @@ class DraftRunService:
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
@@ -567,9 +586,7 @@ class DraftRunService:
# print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
@@ -592,6 +609,23 @@ class DraftRunService:
}
)
# 加载技能关联的工具
skill_configs = {}
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -628,7 +662,6 @@ class DraftRunService:
}
)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],

View File

@@ -0,0 +1,109 @@
"""Skill Service"""
import uuid
from typing import List
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from app.repositories.skill_repository import SkillRepository
from app.schemas.skill_schema import SkillCreate, SkillUpdate
from app.models.skill_model import Skill
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.tool_service import ToolService
class SkillService:
"""Skill 业务逻辑层"""
@staticmethod
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
"""创建技能"""
skill = SkillRepository.create(db, data, tenant_id)
db.commit()
db.refresh(skill)
return skill
@staticmethod
def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill:
"""获取技能"""
try:
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
if not skill:
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
return skill
except (BusinessException, SQLAlchemyError) as e:
db.rollback()
raise e
@staticmethod
def list_skills(
db: Session,
tenant_id: uuid.UUID,
search: str = None,
is_active: bool = None,
is_public: bool = None,
page: int = 1,
pagesize: int = 10
) -> tuple[list[type[Skill]], int]:
"""列出技能"""
return SkillRepository.list_skills(
db, tenant_id, search, is_active, is_public, page, pagesize
)
@staticmethod
def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill:
"""更新技能"""
try:
skill = SkillRepository.update(db, skill_id, data, tenant_id)
if not skill:
raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND)
db.commit()
db.refresh(skill)
return skill
except (BusinessException, SQLAlchemyError) as e:
db.rollback()
raise e
@staticmethod
def delete_skill(db: Session, skill_id: uuid.UUID, workspace_id: uuid.UUID) -> bool:
"""删除技能"""
try:
success = SkillRepository.delete(db, skill_id, workspace_id)
if not success:
raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND)
db.commit()
return True
except (BusinessException, SQLAlchemyError) as e:
db.rollback()
raise e
@staticmethod
def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]:
"""加载技能关联的工具
Returns:
(tools, tool_to_skill_map) - 工具列表和工具到技能的映射
"""
tools = []
tool_to_skill_map = {} # {tool_name: skill_id}
tool_service = ToolService(db)
for skill_id in skill_ids:
try:
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id))
if skill and skill.is_active:
# 加载技能关联的工具
for tool_config in skill.tools:
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool:
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 建立工具到技能的映射
tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool)))
tool_to_skill_map[tool_name] = skill_id
except Exception as e:
print(f"加载技能 {skill_id} 的工具时出错: {e}")
continue
return tools, tool_to_skill_map