From b8f1095f53afce83be40c9fdf7298458c312b20b Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Wed, 4 Feb 2026 12:21:38 +0800 Subject: [PATCH] feat(skills and model): 1. Add the "Skills" module; 2. The loading of the model square has been modified to be controlled through environment variables; 3. Dynamic scheduling of the skill binding tool; 4. Agent Integration Skills --- api/app/controllers/__init__.py | 2 + api/app/controllers/skill_controller.py | 90 +++++++++++ api/app/core/agent/agent_middleware.py | 151 ++++++++++++++++++ api/app/core/config.py | 3 + .../core/models/scripts/bedrock_models.yaml | 1 - .../core/models/scripts/dashscope_models.yaml | 1 - api/app/core/models/scripts/loader.py | 33 +--- .../core/models/scripts/openai_models.yaml | 1 - api/app/main.py | 17 +- api/app/models/__init__.py | 4 +- api/app/models/agent_app_config_model.py | 1 + api/app/models/skill_model.py | 37 +++++ api/app/repositories/skill_repository.py | 111 +++++++++++++ api/app/schemas/app_schema.py | 8 + api/app/schemas/skill_schema.py | 57 +++++++ api/app/services/agent_config_converter.py | 9 ++ api/app/services/agent_config_helper.py | 2 + api/app/services/app_service.py | 2 + api/app/services/app_statistics_service.py | 2 +- api/app/services/draft_run_service.py | 55 +++++-- api/app/services/skill_service.py | 109 +++++++++++++ 21 files changed, 642 insertions(+), 54 deletions(-) create mode 100644 api/app/controllers/skill_controller.py create mode 100644 api/app/core/agent/agent_middleware.py create mode 100644 api/app/models/skill_model.py create mode 100644 api/app/repositories/skill_repository.py create mode 100644 api/app/schemas/skill_schema.py create mode 100644 api/app/services/skill_service.py diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 765ef967..a887458d 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -46,6 +46,7 @@ from . import ( memory_perceptual_controller, memory_working_controller, ontology_controller, + skill_controller ) # 创建管理端 API 路由器 @@ -92,5 +93,6 @@ manager_router.include_router(memory_perceptual_controller.router) manager_router.include_router(memory_working_controller.router) manager_router.include_router(file_storage_controller.router) manager_router.include_router(ontology_controller.router) +manager_router.include_router(skill_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/skill_controller.py b/api/app/controllers/skill_controller.py new file mode 100644 index 00000000..2308307b --- /dev/null +++ b/api/app/controllers/skill_controller.py @@ -0,0 +1,90 @@ +"""Skill Controller - 技能市场管理""" +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session +from typing import Optional +import uuid + +from app.db import get_db +from app.dependencies import get_current_user, cur_workspace_access_guard +from app.models import User +from app.schemas import skill_schema +from app.schemas.response_schema import PageData, PageMeta +from app.services.skill_service import SkillService +from app.core.response_utils import success + +router = APIRouter(prefix="/skills", tags=["Skills"]) + + +@router.post("", summary="创建技能") +@cur_workspace_access_guard() +def create_skill( + data: skill_schema.SkillCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """创建技能 - 可以关联现有工具(内置、MCP、自定义)""" + tenant_id = current_user.tenant_id + skill = SkillService.create_skill(db, data, tenant_id) + return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功") + + +@router.get("", summary="技能列表") +@cur_workspace_access_guard() +def list_skills( + search: Optional[str] = Query(None, description="搜索关键词"), + is_active: Optional[bool] = Query(None, description="是否激活"), + is_public: Optional[bool] = Query(None, description="是否公开"), + page: int = Query(1, ge=1, description="页码"), + pagesize: int = Query(10, ge=1, le=100, description="每页数量"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """技能市场列表 - 包含本工作空间和公开的技能""" + tenant_id = current_user.tenant_id + skills, total = SkillService.list_skills( + db, tenant_id, search, is_active, is_public, page, pagesize + ) + + items = [skill_schema.Skill.model_validate(s) for s in skills] + meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) + return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功") + + +@router.get("/{skill_id}", summary="获取技能详情") +@cur_workspace_access_guard() +def get_skill( + skill_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取技能详情""" + tenant_id = current_user.tenant_id + skill = SkillService.get_skill(db, skill_id, tenant_id) + return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功") + + +@router.put("/{skill_id}", summary="更新技能") +@cur_workspace_access_guard() +def update_skill( + skill_id: uuid.UUID, + data: skill_schema.SkillUpdate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """更新技能""" + tenant_id = current_user.tenant_id + skill = SkillService.update_skill(db, skill_id, data, tenant_id) + return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功") + + +@router.delete("/{skill_id}", summary="删除技能") +@cur_workspace_access_guard() +def delete_skill( + skill_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """删除技能""" + tenant_id = current_user.tenant_id + SkillService.delete_skill(db, skill_id, tenant_id) + return success(msg="技能删除成功") diff --git a/api/app/core/agent/agent_middleware.py b/api/app/core/agent/agent_middleware.py new file mode 100644 index 00000000..ef5f7847 --- /dev/null +++ b/api/app/core/agent/agent_middleware.py @@ -0,0 +1,151 @@ +"""Agent Middleware - 动态技能过滤""" +import uuid +from typing import List, Dict, Any, Optional +from langchain_core.runnables import RunnablePassthrough + +from app.services.skill_service import SkillService +from app.repositories.skill_repository import SkillRepository + + +class AgentMiddleware: + """Agent 中间件 - 用于动态过滤和加载技能""" + + def __init__(self, skill_ids: Optional[List[str]] = None): + """ + 初始化中间件 + + Args: + skill_ids: 技能ID列表 + """ + self.skill_ids = skill_ids or [] + + @staticmethod + def filter_tools( + tools: List, + message: str = "", + skill_configs: Dict[str, Any] = None, + tool_to_skill_map: Dict[str, str] = None + ) -> tuple[List, List[str]]: + """ + 根据消息内容和技能配置动态过滤工具 + + Args: + tools: 所有可用工具列表 + message: 用户消息(可用于智能过滤) + skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}} + tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id} + + Returns: + (过滤后的工具列表, 激活的技能ID列表) + """ + if not tools: + return [], [] + + # 如果没有技能配置,返回所有工具 + if not skill_configs: + return tools, [] + + # 基于关键词匹配激活技能 + activated_skill_ids = [] + message_lower = message.lower() + + for skill_id, config in skill_configs.items(): + if not config.get('enabled', True): + continue + + keywords = config.get('keywords', []) + # 如果没有关键词限制,或消息包含关键词,则激活该技能 + if not keywords or any(kw.lower() in message_lower for kw in keywords): + activated_skill_ids.append(skill_id) + + # 如果没有工具映射关系,返回所有工具 + if not tool_to_skill_map: + return tools, activated_skill_ids + + # 根据激活的技能过滤工具 + filtered_tools = [] + for tool in tools: + tool_name = getattr(tool, 'name', str(id(tool))) + # 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留 + if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids: + filtered_tools.append(tool) + + return filtered_tools, activated_skill_ids + + def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]: + """ + 加载技能关联的工具 + + Args: + db: 数据库会话 + tenant_id: 租户id + base_tools: 基础工具列表 + + Returns: + (工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id}) + """ + + tools_dict = {} + tool_to_skill_map = {} # 工具名称到技能ID的映射 + + if base_tools: + for tool in base_tools: + tool_name = getattr(tool, 'name', str(id(tool))) + tools_dict[tool_name] = tool + # base_tools 不属于任何 skill,不加入映射 + + skill_configs = {} + + if self.skill_ids: + for skill_id in self.skill_ids: + try: + skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id) + if skill and skill.is_active: + # 保存技能配置(包含prompt) + config = skill.config or {} + config['prompt'] = skill.prompt + config['name'] = skill.name + skill_configs[skill_id] = config + except Exception: + continue + + # 加载技能工具并获取映射关系 + skill_tools, skill_tool_map = SkillService.load_skill_tools(db, self.skill_ids, tenant_id) + + # 只添加不冲突的 skill_tools + for tool in skill_tools: + tool_name = getattr(tool, 'name', str(id(tool))) + if tool_name not in tools_dict: + tools_dict[tool_name] = tool + # 复制映射关系 + if tool_name in skill_tool_map: + tool_to_skill_map[tool_name] = skill_tool_map[tool_name] + + return list(tools_dict.values()), skill_configs, tool_to_skill_map + + @staticmethod + def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str: + """ + 根据激活的技能ID获取对应的提示词 + + Args: + activated_skill_ids: 被激活的技能ID列表 + skill_configs: 技能配置字典 + + Returns: + 合并后的提示词 + """ + prompts = [] + for skill_id in activated_skill_ids: + config = skill_configs.get(skill_id, {}) + prompt = config.get('prompt') + name = config.get('name', 'Skill') + if prompt: + prompts.append(f"# {name}\n{prompt}") + + return "\n\n".join(prompts) if prompts else "" + + @staticmethod + def create_runnable(): + """创建可运行的中间件""" + return RunnablePassthrough() diff --git a/api/app/core/config.py b/api/app/core/config.py index 0de957c7..bf721af9 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -215,6 +215,9 @@ class Settings: # official environment system version SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1") + # model square loading + LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true" + # workflow config WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600)) diff --git a/api/app/core/models/scripts/bedrock_models.yaml b/api/app/core/models/scripts/bedrock_models.yaml index 453aaa13..e5b91d1c 100644 --- a/api/app/core/models/scripts/bedrock_models.yaml +++ b/api/app/core/models/scripts/bedrock_models.yaml @@ -1,5 +1,4 @@ provider: bedrock -enabled: false models: - name: ai21 type: llm diff --git a/api/app/core/models/scripts/dashscope_models.yaml b/api/app/core/models/scripts/dashscope_models.yaml index bcdb467e..df538e72 100644 --- a/api/app/core/models/scripts/dashscope_models.yaml +++ b/api/app/core/models/scripts/dashscope_models.yaml @@ -1,5 +1,4 @@ provider: dashscope -enabled: false models: - name: deepseek-r1-distill-qwen-14b type: llm diff --git a/api/app/core/models/scripts/loader.py b/api/app/core/models/scripts/loader.py index 6469656c..a14d3268 100644 --- a/api/app/core/models/scripts/loader.py +++ b/api/app/core/models/scripts/loader.py @@ -1,11 +1,11 @@ """模型配置加载器 - 用于将预定义模型批量导入到数据库""" -import os from pathlib import Path from typing import Callable import yaml from sqlalchemy.orm import Session + from app.models.models_model import ModelBase, ModelProvider @@ -19,31 +19,9 @@ def _load_yaml_config(provider: ModelProvider) -> list[dict]: with open(config_file, 'r', encoding='utf-8') as f: data = yaml.safe_load(f) - - # 检查是否需要加载(默认为 true) - if not data.get('enabled', True): - return [] - return data.get('models', []) -def _disable_yaml_config(provider: ModelProvider) -> None: - """将YAML文件的enabled标志设置为false""" - config_dir = Path(__file__).parent - config_file = config_dir / f"{provider.value}_models.yaml" - - if not config_file.exists(): - return - - with open(config_file, 'r', encoding='utf-8') as f: - data = yaml.safe_load(f) - - data['enabled'] = False - - with open(config_file, 'w', encoding='utf-8') as f: - yaml.dump(data, f, allow_unicode=True, sort_keys=False) - - def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict: """ 加载模型配置到数据库 @@ -75,8 +53,7 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) if not silent: print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...") - - # provider_success = 0 + for model_data in models: try: # 检查模型是否已存在 @@ -93,7 +70,6 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) if not silent: print(f"更新成功: {model_data['name']}") result["success"] += 1 - # provider_success += 1 else: # 创建新模型 model = ModelBase(**model_data) @@ -102,17 +78,12 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) if not silent: print(f"添加成功: {model_data['name']}") result["success"] += 1 - # provider_success += 1 except Exception as e: db.rollback() if not silent: print(f"添加失败: {model_data['name']} - {str(e)}") result["failed"] += 1 - - # 如果该供应商的模型全部加载成功,将enabled设置为false - # if provider_success == len(models): - _disable_yaml_config(provider) return result diff --git a/api/app/core/models/scripts/openai_models.yaml b/api/app/core/models/scripts/openai_models.yaml index 5a416264..68c63ee2 100644 --- a/api/app/core/models/scripts/openai_models.yaml +++ b/api/app/core/models/scripts/openai_models.yaml @@ -1,5 +1,4 @@ provider: openai -enabled: false models: - name: chatgpt-4o-latest type: llm diff --git a/api/app/main.py b/api/app/main.py index 7e16d2c0..e60c33f1 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -50,13 +50,16 @@ async def lifespan(app: FastAPI): logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)") # 加载预定义模型 - logger.info("开始加载预定义模型...") - try: - with get_db_context() as db: - result = load_models(db, silent=True) - logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个") - except Exception as e: - logger.warning(f"加载预定义模型时出错: {str(e)}") + if settings.LOAD_MODEL: + logger.info("开始加载预定义模型...") + try: + with get_db_context() as db: + result = load_models(db, silent=True) + logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个") + except Exception as e: + logger.warning(f"加载预定义模型时出错: {str(e)}") + else: + logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") logger.info("应用程序启动完成") yield diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 984212de..daf03841 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -28,6 +28,7 @@ from .tool_model import ( ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus ) from .memory_perceptual_model import MemoryPerceptualModel +from .skill_model import Skill from .ontology_scene import OntologyScene from .ontology_class import OntologyClass from .ontology_scene import OntologyScene @@ -84,5 +85,6 @@ __all__ = [ "ExecutionStatus", "MemoryPerceptualModel", "ModelBase", - "LoadBalanceStrategy" + "LoadBalanceStrategy", + "Skill" ] diff --git a/api/app/models/agent_app_config_model.py b/api/app/models/agent_app_config_model.py index 96752c8e..7ed70728 100644 --- a/api/app/models/agent_app_config_model.py +++ b/api/app/models/agent_app_config_model.py @@ -30,6 +30,7 @@ class AgentConfig(Base): memory = Column(JSON, nullable=True, comment="记忆配置") variables = Column(JSON, default=list, nullable=True, comment="变量配置") tools = Column(JSON, default=dict, nullable=True, comment="工具配置") + skill_ids = Column(JSON, default=list, nullable=True, comment="关联的技能ID列表") # 多 Agent 相关字段 agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone") diff --git a/api/app/models/skill_model.py b/api/app/models/skill_model.py new file mode 100644 index 00000000..97fdeb03 --- /dev/null +++ b/api/app/models/skill_model.py @@ -0,0 +1,37 @@ +"""Skill 模型定义""" +import datetime +import uuid +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey +from sqlalchemy.dialects.postgresql import UUID, JSON + +from app.db import Base + + +class Skill(Base): + """技能模型 - 可以关联工具(内置、MCP、自定义)""" + __tablename__ = "skills" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + name = Column(String, nullable=False, comment="技能名称") + description = Column(Text, comment="技能描述") + tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID") + + # 关联的工具 + tools = Column(JSON, default=list, comment="关联的工具列表") + + # 技能配置 + config = Column(JSON, default=dict, comment="技能配置") + + # 专属提示词 + prompt = Column(Text, comment="技能专属提示词") + + # 状态 + is_active = Column(Boolean, default=True, nullable=False, comment="是否激活") + is_public = Column(Boolean, default=False, nullable=False, comment="是否公开到市场") + + # 时间戳 + created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") + + def __repr__(self): + return f"" diff --git a/api/app/repositories/skill_repository.py b/api/app/repositories/skill_repository.py new file mode 100644 index 00000000..6eeb7e08 --- /dev/null +++ b/api/app/repositories/skill_repository.py @@ -0,0 +1,111 @@ +"""Skill Repository""" +from typing import List, Optional, Tuple, Any +from sqlalchemy.orm import Session +from sqlalchemy import and_, or_ +import uuid + +from app.models.skill_model import Skill +from app.schemas.skill_schema import SkillCreate, SkillUpdate + + +class SkillRepository: + """Skill 数据访问层""" + + @staticmethod + def create(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill: + """创建技能""" + skill = Skill( + **data.model_dump(), + tenant_id=tenant_id + ) + db.add(skill) + db.flush() + return skill + + @staticmethod + def get_by_id(db: Session, skill_id: uuid.UUID, tenant_id: Optional[uuid.UUID] = None) -> Optional[Skill]: + """根据ID获取技能""" + query = db.query(Skill).filter(Skill.id == skill_id) + if tenant_id: + query = query.filter( + or_( + Skill.tenant_id == tenant_id, + Skill.is_public == True + ) + ) + return query.first() + + @staticmethod + def list_skills( + db: Session, + tenant_id: uuid.UUID, + search: Optional[str] = None, + is_active: Optional[bool] = None, + is_public: Optional[bool] = None, + page: int = 1, + pagesize: int = 10 + ) -> tuple[list[type[Skill]], int]: + """列出技能""" + filters = [ + or_( + Skill.tenant_id == tenant_id, + Skill.is_public == True + ) + ] + + if search: + filters.append( + or_( + Skill.name.ilike(f"%{search}%"), + # Skill.description.ilike(f"%{search}%") + ) + ) + + if is_active is not None: + filters.append(Skill.is_active == is_active) + + if is_public is not None: + filters.append(Skill.is_public == is_public) + + query = db.query(Skill).filter(and_(*filters)) + total = query.count() + + skills = query.order_by(Skill.created_at.desc()).offset( + (page - 1) * pagesize + ).limit(pagesize).all() + + return skills, total + + @staticmethod + def update(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Optional[Skill]: + """更新技能""" + skill = db.query(Skill).filter( + Skill.id == skill_id, + Skill.tenant_id == tenant_id + ).first() + + if not skill: + return None + + update_data = data.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(skill, key, value) + + db.flush() + return skill + + @staticmethod + def delete(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: + """删除技能""" + skill = db.query(Skill).filter( + Skill.id == skill_id, + Skill.tenant_id == tenant_id + ).first() + + if not skill: + return False + + # db.delete(skill) + skill.is_active = False + db.flush() + return True diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 26d9b246..bcfeca57 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -156,6 +156,9 @@ class AgentConfigCreate(BaseModel): description="Agent 可用的工具列表" ) + # 技能配置 + skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表") + class AppCreate(BaseModel): name: str @@ -207,6 +210,9 @@ class AgentConfigUpdate(BaseModel): # 工具配置 tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表") + + # 技能配置 + skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表") # ---------- Output Schemas ---------- @@ -266,6 +272,8 @@ class AgentConfig(BaseModel): # 工具配置 tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = [] + skill_ids: Optional[List[str]] = [] + is_active: bool created_at: datetime.datetime updated_at: datetime.datetime diff --git a/api/app/schemas/skill_schema.py b/api/app/schemas/skill_schema.py new file mode 100644 index 00000000..27f16b99 --- /dev/null +++ b/api/app/schemas/skill_schema.py @@ -0,0 +1,57 @@ +"""Skill Schema 定义""" +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field, field_serializer +import uuid +from datetime import datetime + + +class SkillBase(BaseModel): + """Skill 基础 Schema""" + name: str = Field(..., description="技能名称") + description: Optional[str] = Field(None, description="技能描述") + tools: List[Dict[str, str]] = Field(default_factory=list, description="工具对象列表: [{\"tool_id\": \"xxx\", \"operation\": \"yyy\"}]") + config: Dict[str, Any] = Field(default_factory=dict, description="技能配置") + prompt: Optional[str] = Field(None, description="技能专属提示词") + is_active: bool = Field(True, description="是否激活") + is_public: bool = Field(False, description="是否公开到市场") + + +class SkillCreate(SkillBase): + """创建 Skill""" + pass + + +class SkillUpdate(BaseModel): + """更新 Skill""" + name: Optional[str] = None + description: Optional[str] = None + tools: Optional[List[Dict[str, str]]] = None + config: Optional[Dict[str, Any]] = None + prompt: Optional[str] = None + is_active: Optional[bool] = None + is_public: Optional[bool] = None + + +class Skill(SkillBase): + """Skill 响应 Schema""" + id: uuid.UUID + tenant_id: uuid.UUID + created_at: datetime + updated_at: datetime + + @field_serializer('created_at', 'updated_at') + def serialize_datetime_to_timestamp(self, value: datetime) -> int: + """(毫秒级)时间戳""" + return int(value.timestamp() * 1000) + + class Config: + from_attributes = True + + +class SkillQuery(BaseModel): + """Skill 查询参数""" + search: Optional[str] = None + is_active: Optional[bool] = None + is_public: Optional[bool] = None + page: int = Field(1, ge=1) + pagesize: int = Field(10, ge=1, le=100) diff --git a/api/app/services/agent_config_converter.py b/api/app/services/agent_config_converter.py index 094aade8..ba76e299 100644 --- a/api/app/services/agent_config_converter.py +++ b/api/app/services/agent_config_converter.py @@ -48,6 +48,9 @@ class AgentConfigConverter: # 5. 工具配置 if hasattr(config, 'tools') and config.tools: result["tools"] = [tool.model_dump() for tool in config.tools] + + if hasattr(config, "skill_ids") and config.skill_ids: + result["skill_ids"] = [skill for skill in config.skill_ids] return result @@ -58,6 +61,7 @@ class AgentConfigConverter: memory: Optional[Dict[str, Any]], variables: Optional[list], tools: Optional[Union[list, Dict[str, Any]]], + skill_ids: Optional[list] ) -> Dict[str, Any]: """ 将数据库存储格式转换为 Pydantic 对象 @@ -68,6 +72,7 @@ class AgentConfigConverter: memory: 记忆配置 variables: 变量配置 tools: 工具配置 + skill_ids: 技能 ID 列表 Returns: 包含 Pydantic 对象的字典 @@ -78,6 +83,7 @@ class AgentConfigConverter: "memory": MemoryConfig(enabled=True), "variables": [], "tools": [], + "skill_ids": [] } # 1. 解析模型参数配置 @@ -117,5 +123,8 @@ class AgentConfigConverter: name: ToolOldConfig(**tool_data) for name, tool_data in tools.items() } + + if skill_ids: + result["skill_ids"] = [skill for skill in skill_ids] return result diff --git a/api/app/services/agent_config_helper.py b/api/app/services/agent_config_helper.py index ae195913..ef6e22a4 100644 --- a/api/app/services/agent_config_helper.py +++ b/api/app/services/agent_config_helper.py @@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig: memory=agent_cfg.memory, variables=agent_cfg.variables, tools=agent_cfg.tools, + skill_ids=agent_cfg.skill_ids ) # 将解析后的字段添加到对象上(用于序列化) @@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig: agent_cfg.memory = parsed["memory"] agent_cfg.variables = parsed["variables"] agent_cfg.tools = parsed["tools"] + agent_cfg.skill_ids = parsed["skill_ids"] return agent_cfg diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 7ec4bc0e..1759206f 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -304,6 +304,7 @@ class AppService: memory=storage_data.get("memory"), variables=storage_data.get("variables", []), tools=storage_data.get("tools", []), + skill_ids=storage_data.get("skill_ids", []), is_active=True, created_at=now, updated_at=now, @@ -907,6 +908,7 @@ class AppService: agent_cfg.variables = storage_data.get("variables", []) # if data.tools is not None: agent_cfg.tools = storage_data.get("tools", []) + agent_cfg.skill_ids = storage_data.get("skill_ids", []) agent_cfg.updated_at = now diff --git a/api/app/services/app_statistics_service.py b/api/app/services/app_statistics_service.py index c164924a..1b6bc3b8 100644 --- a/api/app/services/app_statistics_service.py +++ b/api/app/services/app_statistics_service.py @@ -187,7 +187,7 @@ class AppStatisticsService: daily_tokens[date_str] = 0 daily_tokens[date_str] += int(tokens) - daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0] + daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0] total = sum(row["tokens"] for row in daily_data) return {"daily": daily_data, "total": total} diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index edad0123..0e0922bc 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -10,6 +10,11 @@ import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional +from langchain.tools import tool +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -26,10 +31,8 @@ from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger from app.services.tool_service import ToolService from app.services.multimodal_service import MultimodalService -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy import select -from sqlalchemy.orm import Session +from app.core.agent.agent_middleware import AgentMiddleware + logger = get_business_logger() class KnowledgeRetrievalInput(BaseModel): @@ -310,6 +313,7 @@ class DraftRunService: tools = [] tool_service = ToolService(self.db) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): @@ -320,9 +324,7 @@ class DraftRunService: print(f"tool_config:{tool_config}") if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, str(workspace_id))) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: if tool_instance.name == "baidu_search_tool" and not web_search: continue @@ -345,6 +347,22 @@ class DraftRunService: } ) + # 加载技能关联的工具 + if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids: + middleware = AgentMiddleware(skill_ids=agent_config.skill_ids) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" + # 添加知识库检索工具 if agent_config.knowledge_retrieval: kb_config = agent_config.knowledge_retrieval @@ -558,6 +576,7 @@ class DraftRunService: tools = [] tool_service = ToolService(self.db) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): @@ -567,9 +586,7 @@ class DraftRunService: # print(f"tool_config:{tool_config}") if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, str(workspace_id))) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: if tool_instance.name == "baidu_search_tool" and not web_search: continue @@ -592,6 +609,23 @@ class DraftRunService: } ) + # 加载技能关联的工具 + skill_configs = {} + if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids: + middleware = AgentMiddleware(skill_ids=agent_config.skill_ids) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" + # 添加知识库检索工具 if agent_config.knowledge_retrieval: @@ -628,7 +662,6 @@ class DraftRunService: } ) - # 4. 创建 LangChain Agent agent = LangChainAgent( model_name=api_key_config["model_name"], diff --git a/api/app/services/skill_service.py b/api/app/services/skill_service.py new file mode 100644 index 00000000..2ed2971e --- /dev/null +++ b/api/app/services/skill_service.py @@ -0,0 +1,109 @@ +"""Skill Service""" +import uuid +from typing import List + +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from app.repositories.skill_repository import SkillRepository +from app.schemas.skill_schema import SkillCreate, SkillUpdate +from app.models.skill_model import Skill +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.services.tool_service import ToolService + + +class SkillService: + """Skill 业务逻辑层""" + + @staticmethod + def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill: + """创建技能""" + skill = SkillRepository.create(db, data, tenant_id) + db.commit() + db.refresh(skill) + return skill + + @staticmethod + def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill: + """获取技能""" + try: + skill = SkillRepository.get_by_id(db, skill_id, tenant_id) + if not skill: + raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND) + return skill + except (BusinessException, SQLAlchemyError) as e: + db.rollback() + raise e + + @staticmethod + def list_skills( + db: Session, + tenant_id: uuid.UUID, + search: str = None, + is_active: bool = None, + is_public: bool = None, + page: int = 1, + pagesize: int = 10 + ) -> tuple[list[type[Skill]], int]: + """列出技能""" + return SkillRepository.list_skills( + db, tenant_id, search, is_active, is_public, page, pagesize + ) + + @staticmethod + def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill: + """更新技能""" + try: + skill = SkillRepository.update(db, skill_id, data, tenant_id) + if not skill: + raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND) + db.commit() + db.refresh(skill) + return skill + except (BusinessException, SQLAlchemyError) as e: + db.rollback() + raise e + + @staticmethod + def delete_skill(db: Session, skill_id: uuid.UUID, workspace_id: uuid.UUID) -> bool: + """删除技能""" + try: + success = SkillRepository.delete(db, skill_id, workspace_id) + if not success: + raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND) + db.commit() + return True + except (BusinessException, SQLAlchemyError) as e: + db.rollback() + raise e + + @staticmethod + def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]: + """加载技能关联的工具 + + Returns: + (tools, tool_to_skill_map) - 工具列表和工具到技能的映射 + """ + tools = [] + tool_to_skill_map = {} # {tool_name: skill_id} + tool_service = ToolService(db) + + for skill_id in skill_ids: + try: + skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id)) + if skill and skill.is_active: + # 加载技能关联的工具 + for tool_config in skill.tools: + tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + if tool: + langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + # 建立工具到技能的映射 + tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool))) + tool_to_skill_map[tool_name] = skill_id + except Exception as e: + print(f"加载技能 {skill_id} 的工具时出错: {e}") + continue + + return tools, tool_to_skill_map