feat(skills and model):

1. Add the "Skills" module;
 2. The loading of the model square has been modified to be controlled through environment variables;
 3. Dynamic scheduling of the skill binding tool;
 4. Agent Integration Skills
This commit is contained in:
Timebomb2018
2026-02-04 12:21:38 +08:00
parent 3b5df793fb
commit b8f1095f53
21 changed files with 642 additions and 54 deletions

View File

@@ -46,6 +46,7 @@ from . import (
memory_perceptual_controller,
memory_working_controller,
ontology_controller,
skill_controller
)
# 创建管理端 API 路由器
@@ -92,5 +93,6 @@ manager_router.include_router(memory_perceptual_controller.router)
manager_router.include_router(memory_working_controller.router)
manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router)
__all__ = ["manager_router"]

View File

@@ -0,0 +1,90 @@
"""Skill Controller - 技能市场管理"""
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from typing import Optional
import uuid
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
from app.schemas import skill_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services.skill_service import SkillService
from app.core.response_utils import success
router = APIRouter(prefix="/skills", tags=["Skills"])
@router.post("", summary="创建技能")
@cur_workspace_access_guard()
def create_skill(
data: skill_schema.SkillCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建技能 - 可以关联现有工具内置、MCP、自定义"""
tenant_id = current_user.tenant_id
skill = SkillService.create_skill(db, data, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
@router.get("", summary="技能列表")
@cur_workspace_access_guard()
def list_skills(
search: Optional[str] = Query(None, description="搜索关键词"),
is_active: Optional[bool] = Query(None, description="是否激活"),
is_public: Optional[bool] = Query(None, description="是否公开"),
page: int = Query(1, ge=1, description="页码"),
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""技能市场列表 - 包含本工作空间和公开的技能"""
tenant_id = current_user.tenant_id
skills, total = SkillService.list_skills(
db, tenant_id, search, is_active, is_public, page, pagesize
)
items = [skill_schema.Skill.model_validate(s) for s in skills]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
@router.get("/{skill_id}", summary="获取技能详情")
@cur_workspace_access_guard()
def get_skill(
skill_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取技能详情"""
tenant_id = current_user.tenant_id
skill = SkillService.get_skill(db, skill_id, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
@router.put("/{skill_id}", summary="更新技能")
@cur_workspace_access_guard()
def update_skill(
skill_id: uuid.UUID,
data: skill_schema.SkillUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新技能"""
tenant_id = current_user.tenant_id
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
@router.delete("/{skill_id}", summary="删除技能")
@cur_workspace_access_guard()
def delete_skill(
skill_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除技能"""
tenant_id = current_user.tenant_id
SkillService.delete_skill(db, skill_id, tenant_id)
return success(msg="技能删除成功")

View File

@@ -0,0 +1,151 @@
"""Agent Middleware - 动态技能过滤"""
import uuid
from typing import List, Dict, Any, Optional
from langchain_core.runnables import RunnablePassthrough
from app.services.skill_service import SkillService
from app.repositories.skill_repository import SkillRepository
class AgentMiddleware:
"""Agent 中间件 - 用于动态过滤和加载技能"""
def __init__(self, skill_ids: Optional[List[str]] = None):
"""
初始化中间件
Args:
skill_ids: 技能ID列表
"""
self.skill_ids = skill_ids or []
@staticmethod
def filter_tools(
tools: List,
message: str = "",
skill_configs: Dict[str, Any] = None,
tool_to_skill_map: Dict[str, str] = None
) -> tuple[List, List[str]]:
"""
根据消息内容和技能配置动态过滤工具
Args:
tools: 所有可用工具列表
message: 用户消息(可用于智能过滤)
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
Returns:
(过滤后的工具列表, 激活的技能ID列表)
"""
if not tools:
return [], []
# 如果没有技能配置,返回所有工具
if not skill_configs:
return tools, []
# 基于关键词匹配激活技能
activated_skill_ids = []
message_lower = message.lower()
for skill_id, config in skill_configs.items():
if not config.get('enabled', True):
continue
keywords = config.get('keywords', [])
# 如果没有关键词限制,或消息包含关键词,则激活该技能
if not keywords or any(kw.lower() in message_lower for kw in keywords):
activated_skill_ids.append(skill_id)
# 如果没有工具映射关系,返回所有工具
if not tool_to_skill_map:
return tools, activated_skill_ids
# 根据激活的技能过滤工具
filtered_tools = []
for tool in tools:
tool_name = getattr(tool, 'name', str(id(tool)))
# 如果工具不属于任何skillbase_tools或者工具所属的skill被激活则保留
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
filtered_tools.append(tool)
return filtered_tools, activated_skill_ids
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
"""
加载技能关联的工具
Args:
db: 数据库会话
tenant_id: 租户id
base_tools: 基础工具列表
Returns:
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
"""
tools_dict = {}
tool_to_skill_map = {} # 工具名称到技能ID的映射
if base_tools:
for tool in base_tools:
tool_name = getattr(tool, 'name', str(id(tool)))
tools_dict[tool_name] = tool
# base_tools 不属于任何 skill不加入映射
skill_configs = {}
if self.skill_ids:
for skill_id in self.skill_ids:
try:
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
if skill and skill.is_active:
# 保存技能配置包含prompt
config = skill.config or {}
config['prompt'] = skill.prompt
config['name'] = skill.name
skill_configs[skill_id] = config
except Exception:
continue
# 加载技能工具并获取映射关系
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, self.skill_ids, tenant_id)
# 只添加不冲突的 skill_tools
for tool in skill_tools:
tool_name = getattr(tool, 'name', str(id(tool)))
if tool_name not in tools_dict:
tools_dict[tool_name] = tool
# 复制映射关系
if tool_name in skill_tool_map:
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
return list(tools_dict.values()), skill_configs, tool_to_skill_map
@staticmethod
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
"""
根据激活的技能ID获取对应的提示词
Args:
activated_skill_ids: 被激活的技能ID列表
skill_configs: 技能配置字典
Returns:
合并后的提示词
"""
prompts = []
for skill_id in activated_skill_ids:
config = skill_configs.get(skill_id, {})
prompt = config.get('prompt')
name = config.get('name', 'Skill')
if prompt:
prompts.append(f"# {name}\n{prompt}")
return "\n\n".join(prompts) if prompts else ""
@staticmethod
def create_runnable():
"""创建可运行的中间件"""
return RunnablePassthrough()

View File

@@ -215,6 +215,9 @@ class Settings:
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
# model square loading
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
# workflow config
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))

View File

@@ -1,5 +1,4 @@
provider: bedrock
enabled: false
models:
- name: ai21
type: llm

View File

@@ -1,5 +1,4 @@
provider: dashscope
enabled: false
models:
- name: deepseek-r1-distill-qwen-14b
type: llm

View File

@@ -1,11 +1,11 @@
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
import os
from pathlib import Path
from typing import Callable
import yaml
from sqlalchemy.orm import Session
from app.models.models_model import ModelBase, ModelProvider
@@ -19,31 +19,9 @@ def _load_yaml_config(provider: ModelProvider) -> list[dict]:
with open(config_file, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
# 检查是否需要加载(默认为 true
if not data.get('enabled', True):
return []
return data.get('models', [])
def _disable_yaml_config(provider: ModelProvider) -> None:
"""将YAML文件的enabled标志设置为false"""
config_dir = Path(__file__).parent
config_file = config_dir / f"{provider.value}_models.yaml"
if not config_file.exists():
return
with open(config_file, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
data['enabled'] = False
with open(config_file, 'w', encoding='utf-8') as f:
yaml.dump(data, f, allow_unicode=True, sort_keys=False)
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
"""
加载模型配置到数据库
@@ -75,8 +53,7 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
if not silent:
print(f"\n正在加载 {provider.value}{len(models)} 个模型...")
# provider_success = 0
for model_data in models:
try:
# 检查模型是否已存在
@@ -93,7 +70,6 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
if not silent:
print(f"更新成功: {model_data['name']}")
result["success"] += 1
# provider_success += 1
else:
# 创建新模型
model = ModelBase(**model_data)
@@ -102,17 +78,12 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
if not silent:
print(f"添加成功: {model_data['name']}")
result["success"] += 1
# provider_success += 1
except Exception as e:
db.rollback()
if not silent:
print(f"添加失败: {model_data['name']} - {str(e)}")
result["failed"] += 1
# 如果该供应商的模型全部加载成功将enabled设置为false
# if provider_success == len(models):
_disable_yaml_config(provider)
return result

View File

@@ -1,5 +1,4 @@
provider: openai
enabled: false
models:
- name: chatgpt-4o-latest
type: llm

View File

@@ -50,13 +50,16 @@ async def lifespan(app: FastAPI):
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
# 加载预定义模型
logger.info("开始加载预定义模型...")
try:
with get_db_context() as db:
result = load_models(db, silent=True)
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}")
except Exception as e:
logger.warning(f"加载预定义模型时出错: {str(e)}")
if settings.LOAD_MODEL:
logger.info("开始加载预定义模型...")
try:
with get_db_context() as db:
result = load_models(db, silent=True)
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}")
except Exception as e:
logger.warning(f"加载预定义模型时出错: {str(e)}")
else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
logger.info("应用程序启动完成")
yield

View File

@@ -28,6 +28,7 @@ from .tool_model import (
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
)
from .memory_perceptual_model import MemoryPerceptualModel
from .skill_model import Skill
from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
from .ontology_scene import OntologyScene
@@ -84,5 +85,6 @@ __all__ = [
"ExecutionStatus",
"MemoryPerceptualModel",
"ModelBase",
"LoadBalanceStrategy"
"LoadBalanceStrategy",
"Skill"
]

View File

@@ -30,6 +30,7 @@ class AgentConfig(Base):
memory = Column(JSON, nullable=True, comment="记忆配置")
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
skill_ids = Column(JSON, default=list, nullable=True, comment="关联的技能ID列表")
# 多 Agent 相关字段
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")

View File

@@ -0,0 +1,37 @@
"""Skill 模型定义"""
import datetime
import uuid
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID, JSON
from app.db import Base
class Skill(Base):
"""技能模型 - 可以关联工具内置、MCP、自定义"""
__tablename__ = "skills"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
name = Column(String, nullable=False, comment="技能名称")
description = Column(Text, comment="技能描述")
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID")
# 关联的工具
tools = Column(JSON, default=list, comment="关联的工具列表")
# 技能配置
config = Column(JSON, default=dict, comment="技能配置")
# 专属提示词
prompt = Column(Text, comment="技能专属提示词")
# 状态
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
is_public = Column(Boolean, default=False, nullable=False, comment="是否公开到市场")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
def __repr__(self):
return f"<Skill(id={self.id}, name={self.name})>"

View File

@@ -0,0 +1,111 @@
"""Skill Repository"""
from typing import List, Optional, Tuple, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
import uuid
from app.models.skill_model import Skill
from app.schemas.skill_schema import SkillCreate, SkillUpdate
class SkillRepository:
"""Skill 数据访问层"""
@staticmethod
def create(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
"""创建技能"""
skill = Skill(
**data.model_dump(),
tenant_id=tenant_id
)
db.add(skill)
db.flush()
return skill
@staticmethod
def get_by_id(db: Session, skill_id: uuid.UUID, tenant_id: Optional[uuid.UUID] = None) -> Optional[Skill]:
"""根据ID获取技能"""
query = db.query(Skill).filter(Skill.id == skill_id)
if tenant_id:
query = query.filter(
or_(
Skill.tenant_id == tenant_id,
Skill.is_public == True
)
)
return query.first()
@staticmethod
def list_skills(
db: Session,
tenant_id: uuid.UUID,
search: Optional[str] = None,
is_active: Optional[bool] = None,
is_public: Optional[bool] = None,
page: int = 1,
pagesize: int = 10
) -> tuple[list[type[Skill]], int]:
"""列出技能"""
filters = [
or_(
Skill.tenant_id == tenant_id,
Skill.is_public == True
)
]
if search:
filters.append(
or_(
Skill.name.ilike(f"%{search}%"),
# Skill.description.ilike(f"%{search}%")
)
)
if is_active is not None:
filters.append(Skill.is_active == is_active)
if is_public is not None:
filters.append(Skill.is_public == is_public)
query = db.query(Skill).filter(and_(*filters))
total = query.count()
skills = query.order_by(Skill.created_at.desc()).offset(
(page - 1) * pagesize
).limit(pagesize).all()
return skills, total
@staticmethod
def update(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Optional[Skill]:
"""更新技能"""
skill = db.query(Skill).filter(
Skill.id == skill_id,
Skill.tenant_id == tenant_id
).first()
if not skill:
return None
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(skill, key, value)
db.flush()
return skill
@staticmethod
def delete(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
"""删除技能"""
skill = db.query(Skill).filter(
Skill.id == skill_id,
Skill.tenant_id == tenant_id
).first()
if not skill:
return False
# db.delete(skill)
skill.is_active = False
db.flush()
return True

View File

@@ -156,6 +156,9 @@ class AgentConfigCreate(BaseModel):
description="Agent 可用的工具列表"
)
# 技能配置
skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表")
class AppCreate(BaseModel):
name: str
@@ -207,6 +210,9 @@ class AgentConfigUpdate(BaseModel):
# 工具配置
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
# 技能配置
skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表")
# ---------- Output Schemas ----------
@@ -266,6 +272,8 @@ class AgentConfig(BaseModel):
# 工具配置
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
skill_ids: Optional[List[str]] = []
is_active: bool
created_at: datetime.datetime
updated_at: datetime.datetime

View File

@@ -0,0 +1,57 @@
"""Skill Schema 定义"""
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field, field_serializer
import uuid
from datetime import datetime
class SkillBase(BaseModel):
"""Skill 基础 Schema"""
name: str = Field(..., description="技能名称")
description: Optional[str] = Field(None, description="技能描述")
tools: List[Dict[str, str]] = Field(default_factory=list, description="工具对象列表: [{\"tool_id\": \"xxx\", \"operation\": \"yyy\"}]")
config: Dict[str, Any] = Field(default_factory=dict, description="技能配置")
prompt: Optional[str] = Field(None, description="技能专属提示词")
is_active: bool = Field(True, description="是否激活")
is_public: bool = Field(False, description="是否公开到市场")
class SkillCreate(SkillBase):
"""创建 Skill"""
pass
class SkillUpdate(BaseModel):
"""更新 Skill"""
name: Optional[str] = None
description: Optional[str] = None
tools: Optional[List[Dict[str, str]]] = None
config: Optional[Dict[str, Any]] = None
prompt: Optional[str] = None
is_active: Optional[bool] = None
is_public: Optional[bool] = None
class Skill(SkillBase):
"""Skill 响应 Schema"""
id: uuid.UUID
tenant_id: uuid.UUID
created_at: datetime
updated_at: datetime
@field_serializer('created_at', 'updated_at')
def serialize_datetime_to_timestamp(self, value: datetime) -> int:
"""(毫秒级)时间戳"""
return int(value.timestamp() * 1000)
class Config:
from_attributes = True
class SkillQuery(BaseModel):
"""Skill 查询参数"""
search: Optional[str] = None
is_active: Optional[bool] = None
is_public: Optional[bool] = None
page: int = Field(1, ge=1)
pagesize: int = Field(10, ge=1, le=100)

View File

@@ -48,6 +48,9 @@ class AgentConfigConverter:
# 5. 工具配置
if hasattr(config, 'tools') and config.tools:
result["tools"] = [tool.model_dump() for tool in config.tools]
if hasattr(config, "skill_ids") and config.skill_ids:
result["skill_ids"] = [skill for skill in config.skill_ids]
return result
@@ -58,6 +61,7 @@ class AgentConfigConverter:
memory: Optional[Dict[str, Any]],
variables: Optional[list],
tools: Optional[Union[list, Dict[str, Any]]],
skill_ids: Optional[list]
) -> Dict[str, Any]:
"""
将数据库存储格式转换为 Pydantic 对象
@@ -68,6 +72,7 @@ class AgentConfigConverter:
memory: 记忆配置
variables: 变量配置
tools: 工具配置
skill_ids: 技能 ID 列表
Returns:
包含 Pydantic 对象的字典
@@ -78,6 +83,7 @@ class AgentConfigConverter:
"memory": MemoryConfig(enabled=True),
"variables": [],
"tools": [],
"skill_ids": []
}
# 1. 解析模型参数配置
@@ -117,5 +123,8 @@ class AgentConfigConverter:
name: ToolOldConfig(**tool_data)
for name, tool_data in tools.items()
}
if skill_ids:
result["skill_ids"] = [skill for skill in skill_ids]
return result

View File

@@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
memory=agent_cfg.memory,
variables=agent_cfg.variables,
tools=agent_cfg.tools,
skill_ids=agent_cfg.skill_ids
)
# 将解析后的字段添加到对象上(用于序列化)
@@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
agent_cfg.memory = parsed["memory"]
agent_cfg.variables = parsed["variables"]
agent_cfg.tools = parsed["tools"]
agent_cfg.skill_ids = parsed["skill_ids"]
return agent_cfg

View File

@@ -304,6 +304,7 @@ class AppService:
memory=storage_data.get("memory"),
variables=storage_data.get("variables", []),
tools=storage_data.get("tools", []),
skill_ids=storage_data.get("skill_ids", []),
is_active=True,
created_at=now,
updated_at=now,
@@ -907,6 +908,7 @@ class AppService:
agent_cfg.variables = storage_data.get("variables", [])
# if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.skill_ids = storage_data.get("skill_ids", [])
agent_cfg.updated_at = now

View File

@@ -187,7 +187,7 @@ class AppStatisticsService:
daily_tokens[date_str] = 0
daily_tokens[date_str] += int(tokens)
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
total = sum(row["tokens"] for row in daily_data)
return {"daily": daily_data, "total": total}

View File

@@ -10,6 +10,11 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -26,10 +31,8 @@ from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.tool_service import ToolService
from app.services.multimodal_service import MultimodalService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -310,6 +313,7 @@ class DraftRunService:
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
@@ -320,9 +324,7 @@ class DraftRunService:
print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
@@ -345,6 +347,22 @@ class DraftRunService:
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
@@ -558,6 +576,7 @@ class DraftRunService:
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
@@ -567,9 +586,7 @@ class DraftRunService:
# print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
@@ -592,6 +609,23 @@ class DraftRunService:
}
)
# 加载技能关联的工具
skill_configs = {}
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -628,7 +662,6 @@ class DraftRunService:
}
)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],

View File

@@ -0,0 +1,109 @@
"""Skill Service"""
import uuid
from typing import List
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from app.repositories.skill_repository import SkillRepository
from app.schemas.skill_schema import SkillCreate, SkillUpdate
from app.models.skill_model import Skill
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.tool_service import ToolService
class SkillService:
"""Skill 业务逻辑层"""
@staticmethod
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
"""创建技能"""
skill = SkillRepository.create(db, data, tenant_id)
db.commit()
db.refresh(skill)
return skill
@staticmethod
def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill:
"""获取技能"""
try:
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
if not skill:
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
return skill
except (BusinessException, SQLAlchemyError) as e:
db.rollback()
raise e
@staticmethod
def list_skills(
db: Session,
tenant_id: uuid.UUID,
search: str = None,
is_active: bool = None,
is_public: bool = None,
page: int = 1,
pagesize: int = 10
) -> tuple[list[type[Skill]], int]:
"""列出技能"""
return SkillRepository.list_skills(
db, tenant_id, search, is_active, is_public, page, pagesize
)
@staticmethod
def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill:
"""更新技能"""
try:
skill = SkillRepository.update(db, skill_id, data, tenant_id)
if not skill:
raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND)
db.commit()
db.refresh(skill)
return skill
except (BusinessException, SQLAlchemyError) as e:
db.rollback()
raise e
@staticmethod
def delete_skill(db: Session, skill_id: uuid.UUID, workspace_id: uuid.UUID) -> bool:
"""删除技能"""
try:
success = SkillRepository.delete(db, skill_id, workspace_id)
if not success:
raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND)
db.commit()
return True
except (BusinessException, SQLAlchemyError) as e:
db.rollback()
raise e
@staticmethod
def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]:
"""加载技能关联的工具
Returns:
(tools, tool_to_skill_map) - 工具列表和工具到技能的映射
"""
tools = []
tool_to_skill_map = {} # {tool_name: skill_id}
tool_service = ToolService(db)
for skill_id in skill_ids:
try:
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id))
if skill and skill.is_active:
# 加载技能关联的工具
for tool_config in skill.tools:
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool:
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 建立工具到技能的映射
tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool)))
tool_to_skill_map[tool_name] = skill_id
except Exception as e:
print(f"加载技能 {skill_id} 的工具时出错: {e}")
continue
return tools, tool_to_skill_map