feat(skills and model):
1. Add the "Skills" module; 2. The loading of the model square has been modified to be controlled through environment variables; 3. Dynamic scheduling of the skill binding tool; 4. Agent Integration Skills
This commit is contained in:
@@ -48,6 +48,9 @@ class AgentConfigConverter:
|
||||
# 5. 工具配置
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
result["tools"] = [tool.model_dump() for tool in config.tools]
|
||||
|
||||
if hasattr(config, "skill_ids") and config.skill_ids:
|
||||
result["skill_ids"] = [skill for skill in config.skill_ids]
|
||||
|
||||
return result
|
||||
|
||||
@@ -58,6 +61,7 @@ class AgentConfigConverter:
|
||||
memory: Optional[Dict[str, Any]],
|
||||
variables: Optional[list],
|
||||
tools: Optional[Union[list, Dict[str, Any]]],
|
||||
skill_ids: Optional[list]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将数据库存储格式转换为 Pydantic 对象
|
||||
@@ -68,6 +72,7 @@ class AgentConfigConverter:
|
||||
memory: 记忆配置
|
||||
variables: 变量配置
|
||||
tools: 工具配置
|
||||
skill_ids: 技能 ID 列表
|
||||
|
||||
Returns:
|
||||
包含 Pydantic 对象的字典
|
||||
@@ -78,6 +83,7 @@ class AgentConfigConverter:
|
||||
"memory": MemoryConfig(enabled=True),
|
||||
"variables": [],
|
||||
"tools": [],
|
||||
"skill_ids": []
|
||||
}
|
||||
|
||||
# 1. 解析模型参数配置
|
||||
@@ -117,5 +123,8 @@ class AgentConfigConverter:
|
||||
name: ToolOldConfig(**tool_data)
|
||||
for name, tool_data in tools.items()
|
||||
}
|
||||
|
||||
if skill_ids:
|
||||
result["skill_ids"] = [skill for skill in skill_ids]
|
||||
|
||||
return result
|
||||
|
||||
@@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
||||
memory=agent_cfg.memory,
|
||||
variables=agent_cfg.variables,
|
||||
tools=agent_cfg.tools,
|
||||
skill_ids=agent_cfg.skill_ids
|
||||
)
|
||||
|
||||
# 将解析后的字段添加到对象上(用于序列化)
|
||||
@@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
||||
agent_cfg.memory = parsed["memory"]
|
||||
agent_cfg.variables = parsed["variables"]
|
||||
agent_cfg.tools = parsed["tools"]
|
||||
agent_cfg.skill_ids = parsed["skill_ids"]
|
||||
|
||||
return agent_cfg
|
||||
|
||||
@@ -304,6 +304,7 @@ class AppService:
|
||||
memory=storage_data.get("memory"),
|
||||
variables=storage_data.get("variables", []),
|
||||
tools=storage_data.get("tools", []),
|
||||
skill_ids=storage_data.get("skill_ids", []),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -907,6 +908,7 @@ class AppService:
|
||||
agent_cfg.variables = storage_data.get("variables", [])
|
||||
# if data.tools is not None:
|
||||
agent_cfg.tools = storage_data.get("tools", [])
|
||||
agent_cfg.skill_ids = storage_data.get("skill_ids", [])
|
||||
|
||||
agent_cfg.updated_at = now
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class AppStatisticsService:
|
||||
daily_tokens[date_str] = 0
|
||||
daily_tokens[date_str] += int(tokens)
|
||||
|
||||
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
total = sum(row["tokens"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
@@ -10,6 +10,11 @@ import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -26,10 +31,8 @@ from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
|
||||
|
||||
logger = get_business_logger()
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
@@ -310,6 +313,7 @@ class DraftRunService:
|
||||
tools = []
|
||||
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
@@ -320,9 +324,7 @@ class DraftRunService:
|
||||
print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
@@ -345,6 +347,22 @@ class DraftRunService:
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
|
||||
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
kb_config = agent_config.knowledge_retrieval
|
||||
@@ -558,6 +576,7 @@ class DraftRunService:
|
||||
tools = []
|
||||
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
@@ -567,9 +586,7 @@ class DraftRunService:
|
||||
# print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
@@ -592,6 +609,23 @@ class DraftRunService:
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
skill_configs = {}
|
||||
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
|
||||
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
@@ -628,7 +662,6 @@ class DraftRunService:
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 4. 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_config["model_name"],
|
||||
|
||||
109
api/app/services/skill_service.py
Normal file
109
api/app/services/skill_service.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Skill Service"""
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.skill_repository import SkillRepository
|
||||
from app.schemas.skill_schema import SkillCreate, SkillUpdate
|
||||
from app.models.skill_model import Skill
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.tool_service import ToolService
|
||||
|
||||
|
||||
class SkillService:
|
||||
"""Skill 业务逻辑层"""
|
||||
|
||||
@staticmethod
|
||||
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
|
||||
"""创建技能"""
|
||||
skill = SkillRepository.create(db, data, tenant_id)
|
||||
db.commit()
|
||||
db.refresh(skill)
|
||||
return skill
|
||||
|
||||
@staticmethod
|
||||
def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill:
|
||||
"""获取技能"""
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
|
||||
if not skill:
|
||||
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
|
||||
return skill
|
||||
except (BusinessException, SQLAlchemyError) as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def list_skills(
|
||||
db: Session,
|
||||
tenant_id: uuid.UUID,
|
||||
search: str = None,
|
||||
is_active: bool = None,
|
||||
is_public: bool = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 10
|
||||
) -> tuple[list[type[Skill]], int]:
|
||||
"""列出技能"""
|
||||
return SkillRepository.list_skills(
|
||||
db, tenant_id, search, is_active, is_public, page, pagesize
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill:
|
||||
"""更新技能"""
|
||||
try:
|
||||
skill = SkillRepository.update(db, skill_id, data, tenant_id)
|
||||
if not skill:
|
||||
raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND)
|
||||
db.commit()
|
||||
db.refresh(skill)
|
||||
return skill
|
||||
except (BusinessException, SQLAlchemyError) as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def delete_skill(db: Session, skill_id: uuid.UUID, workspace_id: uuid.UUID) -> bool:
|
||||
"""删除技能"""
|
||||
try:
|
||||
success = SkillRepository.delete(db, skill_id, workspace_id)
|
||||
if not success:
|
||||
raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND)
|
||||
db.commit()
|
||||
return True
|
||||
except (BusinessException, SQLAlchemyError) as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]:
|
||||
"""加载技能关联的工具
|
||||
|
||||
Returns:
|
||||
(tools, tool_to_skill_map) - 工具列表和工具到技能的映射
|
||||
"""
|
||||
tools = []
|
||||
tool_to_skill_map = {} # {tool_name: skill_id}
|
||||
tool_service = ToolService(db)
|
||||
|
||||
for skill_id in skill_ids:
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id))
|
||||
if skill and skill.is_active:
|
||||
# 加载技能关联的工具
|
||||
for tool_config in skill.tools:
|
||||
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool:
|
||||
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
# 建立工具到技能的映射
|
||||
tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool)))
|
||||
tool_to_skill_map[tool_name] = skill_id
|
||||
except Exception as e:
|
||||
print(f"加载技能 {skill_id} 的工具时出错: {e}")
|
||||
continue
|
||||
|
||||
return tools, tool_to_skill_map
|
||||
Reference in New Issue
Block a user