Merge branch 'feature/skill_zy' of https://github.com/SuanmoSuanyangTechnology/MemoryBear into feature/skill_zy
This commit is contained in:
@@ -10,14 +10,17 @@ from app.repositories.skill_repository import SkillRepository
|
|||||||
class AgentMiddleware:
|
class AgentMiddleware:
|
||||||
"""Agent 中间件 - 用于动态过滤和加载技能"""
|
"""Agent 中间件 - 用于动态过滤和加载技能"""
|
||||||
|
|
||||||
def __init__(self, skill_ids: Optional[List[str]] = None):
|
def __init__(self, skills: Optional[dict] = None):
|
||||||
"""
|
"""
|
||||||
初始化中间件
|
初始化中间件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
skill_ids: 技能ID列表
|
skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]}
|
||||||
"""
|
"""
|
||||||
self.skill_ids = skill_ids or []
|
self.skills = skills or {}
|
||||||
|
self.enabled = self.skills.get('enabled', False)
|
||||||
|
self.all_skills = self.skills.get('all_skills', False)
|
||||||
|
self.skill_ids = self.skills.get('skill_ids', [])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def filter_tools(
|
def filter_tools(
|
||||||
@@ -95,9 +98,17 @@ class AgentMiddleware:
|
|||||||
# base_tools 不属于任何 skill,不加入映射
|
# base_tools 不属于任何 skill,不加入映射
|
||||||
|
|
||||||
skill_configs = {}
|
skill_configs = {}
|
||||||
|
skill_ids_to_load = []
|
||||||
|
|
||||||
if self.skill_ids:
|
# 如果启用技能且 all_skills 为 True,加载租户下所有激活的技能
|
||||||
for skill_id in self.skill_ids:
|
if self.enabled and self.all_skills:
|
||||||
|
skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
|
||||||
|
skill_ids_to_load = [str(skill.id) for skill in skills]
|
||||||
|
elif self.enabled and self.skill_ids:
|
||||||
|
skill_ids_to_load = self.skill_ids
|
||||||
|
|
||||||
|
if skill_ids_to_load:
|
||||||
|
for skill_id in skill_ids_to_load:
|
||||||
try:
|
try:
|
||||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||||
if skill and skill.is_active:
|
if skill and skill.is_active:
|
||||||
@@ -110,7 +121,7 @@ class AgentMiddleware:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 加载技能工具并获取映射关系
|
# 加载技能工具并获取映射关系
|
||||||
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, self.skill_ids, tenant_id)
|
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)
|
||||||
|
|
||||||
# 只添加不冲突的 skill_tools
|
# 只添加不冲突的 skill_tools
|
||||||
for tool in skill_tools:
|
for tool in skill_tools:
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ class AgentConfig(Base):
|
|||||||
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
|
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
|
||||||
memory = Column(JSON, nullable=True, comment="记忆配置")
|
memory = Column(JSON, nullable=True, comment="记忆配置")
|
||||||
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
||||||
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
|
tools = Column(JSON, default=list, nullable=True, comment="工具配置")
|
||||||
skill_ids = Column(JSON, default=list, nullable=True, comment="关联的技能ID列表")
|
skills = Column(JSON, default=dict, nullable=True, comment="技能配置")
|
||||||
|
|
||||||
# 多 Agent 相关字段
|
# 多 Agent 相关字段
|
||||||
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
||||||
|
|||||||
@@ -82,6 +82,12 @@ class ToolConfig(BaseModel):
|
|||||||
tool_id: Optional[str] = Field(default=None, description="工具ID")
|
tool_id: Optional[str] = Field(default=None, description="工具ID")
|
||||||
operation: Optional[str] = Field(default=None, description="工具特定配置")
|
operation: Optional[str] = Field(default=None, description="工具特定配置")
|
||||||
|
|
||||||
|
class SkillConfig(BaseModel):
|
||||||
|
"""技能配置"""
|
||||||
|
enabled: bool = Field(default=True, description="是否启用该技能")
|
||||||
|
skill_ids: Optional[list[str]] = Field(default=list, description="技能ID列表")
|
||||||
|
all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能")
|
||||||
|
|
||||||
|
|
||||||
class ToolOldConfig(BaseModel):
|
class ToolOldConfig(BaseModel):
|
||||||
"""工具配置"""
|
"""工具配置"""
|
||||||
@@ -157,7 +163,7 @@ class AgentConfigCreate(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 技能配置
|
# 技能配置
|
||||||
skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表")
|
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
|
||||||
|
|
||||||
|
|
||||||
class AppCreate(BaseModel):
|
class AppCreate(BaseModel):
|
||||||
@@ -212,7 +218,7 @@ class AgentConfigUpdate(BaseModel):
|
|||||||
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
|
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
|
||||||
|
|
||||||
# 技能配置
|
# 技能配置
|
||||||
skill_ids: Optional[List[str]] = Field(default=None, description="关联的技能ID列表")
|
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
|
||||||
|
|
||||||
|
|
||||||
# ---------- Output Schemas ----------
|
# ---------- Output Schemas ----------
|
||||||
@@ -272,7 +278,7 @@ class AgentConfig(BaseModel):
|
|||||||
# 工具配置
|
# 工具配置
|
||||||
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
|
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
|
||||||
|
|
||||||
skill_ids: Optional[List[str]] = []
|
skills: Optional[SkillConfig] = {}
|
||||||
|
|
||||||
is_active: bool
|
is_active: bool
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Skill Schema 定义"""
|
"""Skill Schema 定义"""
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any, Union
|
||||||
from pydantic import BaseModel, Field, field_serializer
|
from pydantic import BaseModel, Field, field_serializer
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -32,10 +32,17 @@ class SkillUpdate(BaseModel):
|
|||||||
is_public: Optional[bool] = None
|
is_public: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
class Skill(SkillBase):
|
class Skill(BaseModel):
|
||||||
"""Skill 响应 Schema"""
|
"""Skill 响应 Schema"""
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
tenant_id: uuid.UUID
|
tenant_id: uuid.UUID
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
tools: Union[List[Dict[str, Any]], List[Dict[str, str]]] = Field(default_factory=list, description="工具列表,可以是简单格式或包含工具详情")
|
||||||
|
config: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
prompt: Optional[str] = None
|
||||||
|
is_active: bool = True
|
||||||
|
is_public: bool = False
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from app.schemas.app_schema import (
|
|||||||
VariableDefinition,
|
VariableDefinition,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
AgentConfigCreate,
|
AgentConfigCreate,
|
||||||
AgentConfigUpdate, ToolOldConfig,
|
AgentConfigUpdate, ToolOldConfig, SkillConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -49,8 +49,8 @@ class AgentConfigConverter:
|
|||||||
if hasattr(config, 'tools') and config.tools:
|
if hasattr(config, 'tools') and config.tools:
|
||||||
result["tools"] = [tool.model_dump() for tool in config.tools]
|
result["tools"] = [tool.model_dump() for tool in config.tools]
|
||||||
|
|
||||||
if hasattr(config, "skill_ids") and config.skill_ids:
|
if hasattr(config, "skills") and config.skills:
|
||||||
result["skill_ids"] = [skill for skill in config.skill_ids]
|
result["skills"] = config.skills.model_dump()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -61,7 +61,7 @@ class AgentConfigConverter:
|
|||||||
memory: Optional[Dict[str, Any]],
|
memory: Optional[Dict[str, Any]],
|
||||||
variables: Optional[list],
|
variables: Optional[list],
|
||||||
tools: Optional[Union[list, Dict[str, Any]]],
|
tools: Optional[Union[list, Dict[str, Any]]],
|
||||||
skill_ids: Optional[list]
|
skills: Optional[dict]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
将数据库存储格式转换为 Pydantic 对象
|
将数据库存储格式转换为 Pydantic 对象
|
||||||
@@ -72,7 +72,7 @@ class AgentConfigConverter:
|
|||||||
memory: 记忆配置
|
memory: 记忆配置
|
||||||
variables: 变量配置
|
variables: 变量配置
|
||||||
tools: 工具配置
|
tools: 工具配置
|
||||||
skill_ids: 技能 ID 列表
|
skills: 技能列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含 Pydantic 对象的字典
|
包含 Pydantic 对象的字典
|
||||||
@@ -83,7 +83,7 @@ class AgentConfigConverter:
|
|||||||
"memory": MemoryConfig(enabled=True),
|
"memory": MemoryConfig(enabled=True),
|
||||||
"variables": [],
|
"variables": [],
|
||||||
"tools": [],
|
"tools": [],
|
||||||
"skill_ids": []
|
"skills": SkillConfig(enabled=False, all_skills=False, skill_ids=[])
|
||||||
}
|
}
|
||||||
|
|
||||||
# 1. 解析模型参数配置
|
# 1. 解析模型参数配置
|
||||||
@@ -124,7 +124,9 @@ class AgentConfigConverter:
|
|||||||
for name, tool_data in tools.items()
|
for name, tool_data in tools.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
if skill_ids:
|
if skills:
|
||||||
result["skill_ids"] = [skill for skill in skill_ids]
|
result["skills"] = SkillConfig(**skills)
|
||||||
|
else:
|
||||||
|
result["skills"] = SkillConfig(enabled=False, all_skills=False, skill_ids=[])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
|||||||
memory=agent_cfg.memory,
|
memory=agent_cfg.memory,
|
||||||
variables=agent_cfg.variables,
|
variables=agent_cfg.variables,
|
||||||
tools=agent_cfg.tools,
|
tools=agent_cfg.tools,
|
||||||
skill_ids=agent_cfg.skill_ids
|
skills=agent_cfg.skills
|
||||||
)
|
)
|
||||||
|
|
||||||
# 将解析后的字段添加到对象上(用于序列化)
|
# 将解析后的字段添加到对象上(用于序列化)
|
||||||
@@ -35,6 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
|||||||
agent_cfg.memory = parsed["memory"]
|
agent_cfg.memory = parsed["memory"]
|
||||||
agent_cfg.variables = parsed["variables"]
|
agent_cfg.variables = parsed["variables"]
|
||||||
agent_cfg.tools = parsed["tools"]
|
agent_cfg.tools = parsed["tools"]
|
||||||
agent_cfg.skill_ids = parsed["skill_ids"]
|
agent_cfg.skills = parsed["skills"]
|
||||||
|
|
||||||
return agent_cfg
|
return agent_cfg
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
|||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.agent.agent_middleware import AgentMiddleware
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
@@ -79,21 +80,55 @@ class AppChatService:
|
|||||||
|
|
||||||
# 获取工具服务
|
# 获取工具服务
|
||||||
tool_service = ToolService(self.db)
|
tool_service = ToolService(self.db)
|
||||||
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||||
for tool_config in config.tools:
|
for tool_config in config.tools:
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||||
ToolRepository.get_tenant_id_by_workspace_id(
|
|
||||||
self.db, workspace_id))
|
|
||||||
if tool_instance:
|
if tool_instance:
|
||||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||||
continue
|
continue
|
||||||
# 转换为LangChain工具
|
# 转换为LangChain工具
|
||||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||||
tools.append(langchain_tool)
|
tools.append(langchain_tool)
|
||||||
|
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||||
|
web_tools = config.tools
|
||||||
|
web_search_choice = web_tools.get("web_search", {})
|
||||||
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
|
if web_search:
|
||||||
|
if web_search_enable:
|
||||||
|
search_tool = create_web_search_tool({})
|
||||||
|
tools.append(search_tool)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"已添加网络搜索工具",
|
||||||
|
extra={
|
||||||
|
"tool_count": len(tools)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 加载技能关联的工具
|
||||||
|
if hasattr(config, 'skills') and config.skills:
|
||||||
|
skills = config.skills
|
||||||
|
skill_enable = skills.get("enabled", False)
|
||||||
|
if skill_enable:
|
||||||
|
middleware = AgentMiddleware(skills=skills)
|
||||||
|
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||||
|
tools.extend(skill_tools)
|
||||||
|
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||||
|
|
||||||
|
# 应用动态过滤
|
||||||
|
if skill_configs:
|
||||||
|
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||||
|
tool_to_skill_map)
|
||||||
|
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||||
|
active_prompts = AgentMiddleware.get_active_prompts(
|
||||||
|
activated_skill_ids, skill_configs
|
||||||
|
)
|
||||||
|
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
knowledge_retrieval = config.knowledge_retrieval
|
knowledge_retrieval = config.knowledge_retrieval
|
||||||
@@ -113,22 +148,6 @@ class AppChatService:
|
|||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
|
||||||
web_tools = config.tools
|
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
if web_search:
|
|
||||||
if web_search_enable:
|
|
||||||
search_tool = create_web_search_tool({})
|
|
||||||
tools.append(search_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加网络搜索工具",
|
|
||||||
extra={
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
|
|
||||||
@@ -246,20 +265,54 @@ class AppChatService:
|
|||||||
|
|
||||||
# 获取工具服务
|
# 获取工具服务
|
||||||
tool_service = ToolService(self.db)
|
tool_service = ToolService(self.db)
|
||||||
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||||
|
|
||||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||||
for tool_config in config.tools:
|
for tool_config in config.tools:
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||||
ToolRepository.get_tenant_id_by_workspace_id(
|
|
||||||
self.db, workspace_id))
|
|
||||||
if tool_instance:
|
if tool_instance:
|
||||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||||
continue
|
continue
|
||||||
# 转换为LangChain工具
|
# 转换为LangChain工具
|
||||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||||
tools.append(langchain_tool)
|
tools.append(langchain_tool)
|
||||||
|
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||||
|
web_tools = config.tools
|
||||||
|
web_search_choice = web_tools.get("web_search", {})
|
||||||
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
|
if web_search:
|
||||||
|
if web_search_enable:
|
||||||
|
search_tool = create_web_search_tool({})
|
||||||
|
tools.append(search_tool)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"已添加网络搜索工具",
|
||||||
|
extra={
|
||||||
|
"tool_count": len(tools)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 加载技能关联的工具
|
||||||
|
if hasattr(config, 'skills') and config.skills:
|
||||||
|
skills = config.skills
|
||||||
|
skill_enable = skills.get("enabled", False)
|
||||||
|
if skill_enable:
|
||||||
|
middleware = AgentMiddleware(skills=skills)
|
||||||
|
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||||
|
tools.extend(skill_tools)
|
||||||
|
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||||
|
|
||||||
|
# 应用动态过滤
|
||||||
|
if skill_configs:
|
||||||
|
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||||
|
tool_to_skill_map)
|
||||||
|
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||||
|
active_prompts = AgentMiddleware.get_active_prompts(
|
||||||
|
activated_skill_ids, skill_configs
|
||||||
|
)
|
||||||
|
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
knowledge_retrieval = config.knowledge_retrieval
|
knowledge_retrieval = config.knowledge_retrieval
|
||||||
@@ -279,22 +332,6 @@ class AppChatService:
|
|||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
|
||||||
web_tools = config.tools
|
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
if web_search:
|
|
||||||
if web_search_enable:
|
|
||||||
search_tool = create_web_search_tool({})
|
|
||||||
tools.append(search_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加网络搜索工具",
|
|
||||||
extra={
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
|
|
||||||
|
|||||||
@@ -304,7 +304,7 @@ class AppService:
|
|||||||
memory=storage_data.get("memory"),
|
memory=storage_data.get("memory"),
|
||||||
variables=storage_data.get("variables", []),
|
variables=storage_data.get("variables", []),
|
||||||
tools=storage_data.get("tools", []),
|
tools=storage_data.get("tools", []),
|
||||||
skill_ids=storage_data.get("skill_ids", []),
|
skills=storage_data.get("skills", {}),
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -908,7 +908,7 @@ class AppService:
|
|||||||
agent_cfg.variables = storage_data.get("variables", [])
|
agent_cfg.variables = storage_data.get("variables", [])
|
||||||
# if data.tools is not None:
|
# if data.tools is not None:
|
||||||
agent_cfg.tools = storage_data.get("tools", [])
|
agent_cfg.tools = storage_data.get("tools", [])
|
||||||
agent_cfg.skill_ids = storage_data.get("skill_ids", [])
|
agent_cfg.skills = storage_data.get("skills", {})
|
||||||
|
|
||||||
agent_cfg.updated_at = now
|
agent_cfg.updated_at = now
|
||||||
|
|
||||||
@@ -1001,6 +1001,7 @@ class AppService:
|
|||||||
},
|
},
|
||||||
variables=[],
|
variables=[],
|
||||||
tools=[],
|
tools=[],
|
||||||
|
skills=[],
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -1219,6 +1220,7 @@ class AppService:
|
|||||||
"memory": agent_cfg.memory,
|
"memory": agent_cfg.memory,
|
||||||
"variables": agent_cfg.variables or [],
|
"variables": agent_cfg.variables or [],
|
||||||
"tools": agent_cfg.tools or [],
|
"tools": agent_cfg.tools or [],
|
||||||
|
"skills": agent_cfg.skills or {},
|
||||||
}
|
}
|
||||||
# config = AgentConfigConverter.from_storage_format(agent_cfg)
|
# config = AgentConfigConverter.from_storage_format(agent_cfg)
|
||||||
default_model_config_id = agent_cfg.default_model_config_id
|
default_model_config_id = agent_cfg.default_model_config_id
|
||||||
|
|||||||
@@ -348,20 +348,23 @@ class DraftRunService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 加载技能关联的工具
|
# 加载技能关联的工具
|
||||||
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
|
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||||
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
|
skills = agent_config.skills
|
||||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
skill_enable = skills.get("enabled", False)
|
||||||
tools.extend(skill_tools)
|
if skill_enable:
|
||||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
middleware = AgentMiddleware(skills=skills)
|
||||||
|
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||||
|
tools.extend(skill_tools)
|
||||||
|
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||||
|
|
||||||
# 应用动态过滤
|
# 应用动态过滤
|
||||||
if skill_configs:
|
if skill_configs:
|
||||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||||
active_prompts = AgentMiddleware.get_active_prompts(
|
active_prompts = AgentMiddleware.get_active_prompts(
|
||||||
activated_skill_ids, skill_configs
|
activated_skill_ids, skill_configs
|
||||||
)
|
)
|
||||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
if agent_config.knowledge_retrieval:
|
if agent_config.knowledge_retrieval:
|
||||||
@@ -610,21 +613,23 @@ class DraftRunService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 加载技能关联的工具
|
# 加载技能关联的工具
|
||||||
skill_configs = {}
|
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||||
if hasattr(agent_config, 'skill_ids') and agent_config.skill_ids:
|
skills = agent_config.skills
|
||||||
middleware = AgentMiddleware(skill_ids=agent_config.skill_ids)
|
skill_enable = skills.get("enabled", False)
|
||||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
if skill_enable:
|
||||||
tools.extend(skill_tools)
|
middleware = AgentMiddleware(skills=skills)
|
||||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||||
|
tools.extend(skill_tools)
|
||||||
|
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||||
|
|
||||||
# 应用动态过滤
|
# 应用动态过滤
|
||||||
if skill_configs:
|
if skill_configs:
|
||||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||||
active_prompts = AgentMiddleware.get_active_prompts(
|
active_prompts = AgentMiddleware.get_active_prompts(
|
||||||
activated_skill_ids, skill_configs
|
activated_skill_ids, skill_configs
|
||||||
)
|
)
|
||||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||||
|
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
|
|||||||
@@ -19,6 +19,14 @@ class SkillService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
|
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
|
||||||
"""创建技能"""
|
"""创建技能"""
|
||||||
|
# 检查同名技能
|
||||||
|
existing = db.query(Skill).filter(
|
||||||
|
Skill.tenant_id == tenant_id,
|
||||||
|
Skill.name == data.name
|
||||||
|
).first()
|
||||||
|
if existing:
|
||||||
|
raise BusinessException(f"技能名称'{data.name}'已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
skill = SkillRepository.create(db, data, tenant_id)
|
skill = SkillRepository.create(db, data, tenant_id)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(skill)
|
db.refresh(skill)
|
||||||
@@ -31,6 +39,22 @@ class SkillService:
|
|||||||
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
|
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
|
||||||
if not skill:
|
if not skill:
|
||||||
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
|
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
# 填充工具详情
|
||||||
|
tool_service = ToolService(db)
|
||||||
|
enriched_tools = []
|
||||||
|
for tool_config in skill.tools:
|
||||||
|
tool_id = tool_config.get("tool_id")
|
||||||
|
if tool_id:
|
||||||
|
tool_info = tool_service.get_tool_info(tool_id, tenant_id)
|
||||||
|
if tool_info:
|
||||||
|
enriched_tools.append({
|
||||||
|
"tool_id": tool_id,
|
||||||
|
"operation": tool_config.get("operation"),
|
||||||
|
"tool_info": tool_info
|
||||||
|
})
|
||||||
|
skill.tools = enriched_tools
|
||||||
|
|
||||||
return skill
|
return skill
|
||||||
except (BusinessException, SQLAlchemyError) as e:
|
except (BusinessException, SQLAlchemyError) as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
|||||||
@@ -99,7 +99,8 @@ def agent_config_4_app_release(release: AppRelease) -> AgentConfig:
|
|||||||
knowledge_retrieval=config_dict.get("knowledge_retrieval"),
|
knowledge_retrieval=config_dict.get("knowledge_retrieval"),
|
||||||
memory=config_dict.get("memory"),
|
memory=config_dict.get("memory"),
|
||||||
variables=config_dict.get("variables", []),
|
variables=config_dict.get("variables", []),
|
||||||
tools=config_dict.get("tools", {}),
|
tools=config_dict.get("tools", []),
|
||||||
|
skill_ids=config_dict.get("skill_ids", [])
|
||||||
)
|
)
|
||||||
|
|
||||||
return agent_config
|
return agent_config
|
||||||
|
|||||||
30
api/migrations/versions/9b28b66cf8e8_202602041811.py
Normal file
30
api/migrations/versions/9b28b66cf8e8_202602041811.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""202602041811
|
||||||
|
|
||||||
|
Revision ID: 9b28b66cf8e8
|
||||||
|
Revises: e7c7afa249d1
|
||||||
|
Create Date: 2026-02-04 18:12:12.454402
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '9b28b66cf8e8'
|
||||||
|
down_revision: Union[str, None] = 'e7c7afa249d1'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.alter_column('agent_configs', 'skill_ids', new_column_name='skills', comment='技能配置')
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.alter_column('agent_configs', 'skills', new_column_name='skill_ids', comment='关联的技能ID列表')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -37,7 +37,7 @@ function App() {
|
|||||||
const { checkJump } = useUser();
|
const { checkJump } = useUser();
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const authToken = cookieUtils.get('authToken')
|
const authToken = cookieUtils.get('authToken')
|
||||||
if (!authToken && !window.location.hash.includes('#/login') && !window.location.hash.includes('#/conversation/')) {
|
if (!authToken && !window.location.hash.includes('#/login') && !window.location.hash.includes('#/conversation/') && !window.location.hash.includes('#/jump')) {
|
||||||
window.location.href = `/#/login`;
|
window.location.href = `/#/login`;
|
||||||
} else {
|
} else {
|
||||||
checkJump()
|
checkJump()
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-02 16:33:11
|
* @Date: 2026-02-02 16:33:11
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-04 14:06:45
|
* @Last Modified time: 2026-02-04 18:11:34
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Route Configuration
|
* Route Configuration
|
||||||
@@ -23,7 +23,6 @@ import { createHashRouter, createRoutesFromElements, Route } from 'react-router-
|
|||||||
/** Import route configuration JSON */
|
/** Import route configuration JSON */
|
||||||
import routesConfig from './routes.json';
|
import routesConfig from './routes.json';
|
||||||
|
|
||||||
|
|
||||||
/** Recursively collect all element names from routes */
|
/** Recursively collect all element names from routes */
|
||||||
function collectElements(routes: RouteConfig[]): Set<string> {
|
function collectElements(routes: RouteConfig[]): Set<string> {
|
||||||
const elements = new Set<string>();
|
const elements = new Set<string>();
|
||||||
@@ -91,6 +90,7 @@ const componentMap: Record<string, LazyExoticComponent<ComponentType<object>>> =
|
|||||||
Prompt: lazy(() => import('@/views/Prompt')),
|
Prompt: lazy(() => import('@/views/Prompt')),
|
||||||
Skills: lazy(() => import('@/views/Skills')),
|
Skills: lazy(() => import('@/views/Skills')),
|
||||||
SkillConfig: lazy(() => import('@/views/Skills/pages/SkillConfig')),
|
SkillConfig: lazy(() => import('@/views/Skills/pages/SkillConfig')),
|
||||||
|
Jump: lazy(() => import('@/views/JumpPage')),
|
||||||
Login: lazy(() => import('@/views/Login')),
|
Login: lazy(() => import('@/views/Login')),
|
||||||
InviteRegister: lazy(() => import('@/views/InviteRegister')),
|
InviteRegister: lazy(() => import('@/views/InviteRegister')),
|
||||||
NoPermission: lazy(() => import('@/views/NoPermission')),
|
NoPermission: lazy(() => import('@/views/NoPermission')),
|
||||||
|
|||||||
@@ -61,7 +61,8 @@
|
|||||||
{
|
{
|
||||||
"element": "NoAuthLayout",
|
"element": "NoAuthLayout",
|
||||||
"children": [
|
"children": [
|
||||||
{ "path": "/conversation/:token", "element": "Conversation" }
|
{ "path": "/conversation/:token", "element": "Conversation" },
|
||||||
|
{ "path": "/jump", "element": "Jump" }
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
/*
|
/*
|
||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-02 16:33:54
|
* @Date: 2026-02-02 16:33:54
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-02 16:33:54
|
* @Last Modified time: 2026-02-04 18:30:10
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* User Store
|
* User Store
|
||||||
@@ -59,7 +59,8 @@ export interface UserState {
|
|||||||
export const whitePage = [
|
export const whitePage = [
|
||||||
'/conversation',
|
'/conversation',
|
||||||
'/login',
|
'/login',
|
||||||
'/invite-register'
|
'/invite-register',
|
||||||
|
'jump'
|
||||||
]
|
]
|
||||||
|
|
||||||
/** User store */
|
/** User store */
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ interface ApplicationModalProps {
|
|||||||
/**
|
/**
|
||||||
* Supported application types
|
* Supported application types
|
||||||
*/
|
*/
|
||||||
const types = [
|
export const types = [
|
||||||
'agent',
|
'agent',
|
||||||
'multi_agent',
|
'multi_agent',
|
||||||
'workflow'
|
'workflow'
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
/*
|
/*
|
||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 16:34:12
|
* @Date: 2026-02-03 16:34:12
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-03 16:34:12
|
* @Last Modified time: 2026-02-04 18:57:35
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Application Management Page
|
* Application Management Page
|
||||||
@@ -12,12 +12,12 @@
|
|||||||
|
|
||||||
import React, { useState, useRef } from 'react';
|
import React, { useState, useRef } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { Button, Row, Col, App } from 'antd';
|
import { Button, Row, Col, App, Select } from 'antd';
|
||||||
import clsx from 'clsx';
|
import clsx from 'clsx';
|
||||||
import { DeleteOutlined } from '@ant-design/icons';
|
import { DeleteOutlined } from '@ant-design/icons';
|
||||||
|
|
||||||
import type { Application, ApplicationModalRef, Query } from './types';
|
import type { Application, ApplicationModalRef, Query } from './types';
|
||||||
import ApplicationModal from './components/ApplicationModal';
|
import ApplicationModal, { types } from './components/ApplicationModal';
|
||||||
import SearchInput from '@/components/SearchInput'
|
import SearchInput from '@/components/SearchInput'
|
||||||
import RbCard from '@/components/RbCard/Card'
|
import RbCard from '@/components/RbCard/Card'
|
||||||
import { getApplicationListUrl, deleteApplication } from '@/api/application'
|
import { getApplicationListUrl, deleteApplication } from '@/api/application'
|
||||||
@@ -65,10 +65,25 @@ const ApplicationManagement: React.FC = () => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
const handleChangeType = (value?: string) => {
|
||||||
|
setQuery(prev => ({...prev, type: value}))
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Row gutter={16} className="rb:mb-4">
|
<Row gutter={16} className="rb:mb-4">
|
||||||
<Col span={12}>
|
<Col span={3}>
|
||||||
|
<Select
|
||||||
|
placeholder={t('application.applicationType')}
|
||||||
|
options={types.map((type) => ({
|
||||||
|
value: type,
|
||||||
|
label: t(`application.${type}`),
|
||||||
|
}))}
|
||||||
|
allowClear
|
||||||
|
className="rb:w-full"
|
||||||
|
onChange={handleChangeType}
|
||||||
|
/>
|
||||||
|
</Col>
|
||||||
|
<Col span={9}>
|
||||||
<SearchInput
|
<SearchInput
|
||||||
placeholder={t('application.searchPlaceholder')}
|
placeholder={t('application.searchPlaceholder')}
|
||||||
onSearch={(value) => setQuery({ search: value })}
|
onSearch={(value) => setQuery({ search: value })}
|
||||||
|
|||||||
53
web/src/views/JumpPage.tsx
Normal file
53
web/src/views/JumpPage.tsx
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
/*
|
||||||
|
* @Author: ZhaoYing
|
||||||
|
* @Date: 2026-02-04 18:34:36
|
||||||
|
* @Last Modified by: ZhaoYing
|
||||||
|
* @Last Modified time: 2026-02-04 18:49:59
|
||||||
|
*/
|
||||||
|
import { useEffect, type FC } from 'react'
|
||||||
|
import { useNavigate, useSearchParams } from 'react-router-dom'
|
||||||
|
|
||||||
|
import { cookieUtils } from '@/utils/request'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* JumpPage Component
|
||||||
|
*
|
||||||
|
* This is an intermediate redirect page used for OAuth authentication flow.
|
||||||
|
* It handles the callback from external authentication providers by:
|
||||||
|
* 1. Extracting authentication tokens from URL query parameters
|
||||||
|
* 2. Storing tokens in cookies for subsequent API requests
|
||||||
|
* 3. Redirecting users to their intended destination
|
||||||
|
*
|
||||||
|
* Expected URL format:
|
||||||
|
* /jump?access_token=xxx&refresh_token=yyy&target=/dashboard
|
||||||
|
*
|
||||||
|
* @returns null - This component doesn't render any UI, it only handles side effects
|
||||||
|
*/
|
||||||
|
const JumpPage: FC = () => {
|
||||||
|
const navigate = useNavigate()
|
||||||
|
const [searchParams] = useSearchParams()
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// Convert URLSearchParams to a plain object for easier access
|
||||||
|
const data = Object.fromEntries(searchParams)
|
||||||
|
const { access_token, refresh_token, target } = data
|
||||||
|
|
||||||
|
// Store authentication tokens in cookies for API authorization
|
||||||
|
cookieUtils.set('authToken', access_token)
|
||||||
|
cookieUtils.set('refreshToken', refresh_token)
|
||||||
|
|
||||||
|
// Redirect to the target page if specified
|
||||||
|
if (target) {
|
||||||
|
// Use setTimeout to ensure cookie operations complete before navigation
|
||||||
|
setTimeout(() => {
|
||||||
|
// Replace current history entry to prevent users from going back to this page
|
||||||
|
navigate(target, { replace: true })
|
||||||
|
}, 0)
|
||||||
|
}
|
||||||
|
}, [searchParams, navigate])
|
||||||
|
|
||||||
|
// No UI rendering needed - this is a pure redirect handler
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
export default JumpPage
|
||||||
Reference in New Issue
Block a user