Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

This commit is contained in:
Mark
2026-01-07 19:35:39 +08:00
9 changed files with 621 additions and 1436 deletions

View File

@@ -77,7 +77,7 @@ class AppChatService:
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools:
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
@@ -109,20 +109,21 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
# web_tools = config.tools
# web_search_choice = web_tools.get("web_search", {})
# web_search_enable = web_search_choice.get("enabled", False)
# if web_search == True:
# if web_search_enable == True:
# search_tool = create_web_search_tool({})
# tools.append(search_tool)
#
# logger.debug(
# "已添加网络搜索工具",
# extra={
# "tool_count": len(tools)
# }
# )
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.model_parameters
@@ -226,7 +227,7 @@ class AppChatService:
# 获取工具服务
tool_service = ToolService(self.db)
if hasattr(config, 'tools') and config.tools:
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
@@ -258,20 +259,21 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
# web_tools = config.tools
# web_search_choice = web_tools.get("web_search", {})
# web_search_enable = web_search_choice.get("enabled", False)
# if web_search == True:
# if web_search_enable == True:
# search_tool = create_web_search_tool({})
# tools.append(search_tool)
#
# logger.debug(
# "已添加网络搜索工具",
# extra={
# "tool_count": len(tools)
# }
# )
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.model_parameters

View File

@@ -297,19 +297,35 @@ class DraftRunService:
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_config in agent_config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_config in agent_config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -507,7 +523,7 @@ class DraftRunService:
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools:
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
for tool_config in agent_config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
@@ -520,6 +536,22 @@ class DraftRunService:
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 添加知识库检索工具
if agent_config.knowledge_retrieval:

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from sqlalchemy.orm import Session
from app.core.tools.mcp import MCPClient
from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
from app.repositories.tool_repository import (
ToolRepository, BuiltinToolRepository, CustomToolRepository,
MCPToolRepository, ToolExecutionRepository
@@ -42,6 +42,9 @@ class ToolService:
def __init__(self, db: Session):
self.db = db
self._tool_cache: Dict[str, BaseTool] = {}
# MCP管理器
self.mcp_tool_manager = MCPToolManager(db)
# 初始化仓储
self.tool_repo = ToolRepository()
@@ -675,23 +678,85 @@ class ToolService:
return []
async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
"""获取MCP工具的方法"""
"""获取MCP工具的方法和参数"""
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if not mcp_config:
return []
available_tools = mcp_config.available_tools or []
if not available_tools:
return []
# 如果没有工具列表,尝试同步
try:
success, tools, _ = await self.mcp_tool_manager.discover_tools(
mcp_config.server_url, mcp_config.connection_config or {}
)
if success:
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
mcp_config.available_tools = tool_names
self.db.commit()
available_tools = tool_names
except Exception as e:
logger.error(f"同步MCP工具列表失败: {e}")
return []
methods = []
for tool_name in available_tools:
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": f"MCP工具: {tool_name}",
"parameters": [] # MCP工具参数需要动态获取
})
# 获取工具详细信息
try:
success, tools, _ = await self.mcp_tool_manager.discover_tools(
mcp_config.server_url, mcp_config.connection_config or {}
)
if success:
tools_dict = {tool.get("name"): tool for tool in tools if tool.get("name")}
for tool_name in available_tools:
tool_info = tools_dict.get(tool_name, {})
# 解析工具参数
parameters = []
input_schema = tool_info.get("inputSchema", {})
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
for param_name, param_def in properties.items():
parameters.append({
"name": param_name,
"type": param_def.get("type", "string"),
"description": param_def.get("description", ""),
"required": param_name in required_fields,
"default": param_def.get("default"),
"enum": param_def.get("enum"),
"minimum": param_def.get("minimum"),
"maximum": param_def.get("maximum")
})
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": tool_info.get("description", f"MCP工具: {tool_name}"),
"parameters": parameters
})
else:
# 如果无法获取详细信息,返回基本信息
for tool_name in available_tools:
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": f"MCP工具: {tool_name}",
"parameters": []
})
except Exception as e:
logger.error(f"获取MCP工具详细信息失败: {e}")
# 返回基本信息
for tool_name in available_tools:
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": f"MCP工具: {tool_name}",
"parameters": []
})
return methods
@@ -812,10 +877,14 @@ class ToolService:
if not mcp_config:
return None
# 从配置中获取特定工具名称
tool_name = config.config_data.get("tool_name")
tool_config = {
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config or {},
"available_tools": mcp_config.available_tools or []
"available_tools": mcp_config.available_tools or [],
"tool_name": tool_name # 指定具体工具
}
return MCPTool(str(config.id), tool_config)
@@ -1071,71 +1140,59 @@ class ToolService:
return {}
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
"""测试MCP连接"""
"""测试MCP连接并自动同步工具列表"""
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == config.id
).first()
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if not mcp_config:
return {"success": False, "message": "MCP配置不存在"}
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
if await client.connect():
try:
# tools = await client.list_tools()
await client.disconnect()
# 更新连接状态
# 使用集成的MCP管理器测试连接
test_result = await self.mcp_tool_manager.test_tool_connection(
mcp_config.server_url, mcp_config.connection_config or {}
)
if test_result["success"]:
# 连接成功,自动同步工具列表
success, tools, error = await self.mcp_tool_manager.discover_tools(
mcp_config.server_url, mcp_config.connection_config or {}
)
if success:
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
# 更新数据库
mcp_config.available_tools = tool_names
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "healthy"
mcp_config.error_message = None
# 更新工具状态
self._update_tool_status(config)
config.status = ToolStatus.AVAILABLE.value
self.db.commit()
return {
"success": True,
"message": "MCP连接成功",
# "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)}
"details": {"server_url": mcp_config.server_url}
"message": "MCP连接成功并同步工具列表",
"details": {
"server_url": mcp_config.server_url,
"tools_count": len(tool_names),
"tools": tool_names
}
}
except Exception as e:
await client.disconnect()
# 更新错误状态
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
self._update_tool_status(config)
self.db.commit()
return {"success": False, "message": f"MCP功能测试失败: {str(e)}"}
else:
return {"success": False, "message": f"同步工具失败: {error}"}
else:
# 更新连接失败状态
# 更新错误状态
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = "连接失败"
self._update_tool_status(config)
mcp_config.error_message = test_result.get("error", "连接失败")
config.status = ToolStatus.ERROR.value
self.db.commit()
return {"success": False, "message": "MCP连接失败"}
return test_result
except Exception as e:
# 更新异常状态
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == config.id
).first()
if mcp_config:
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
self._update_tool_status(config)
self.db.commit()
return {"success": False, "message": f"MCP测试异常: {str(e)}"}
logger.error(f"测试MCP连接失败: {config.id}, 错误: {e}")
return {"success": False, "message": f"测试失败: {str(e)}"}
@staticmethod
async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]:
@@ -1190,57 +1247,44 @@ class ToolService:
# 创建MCP客户端
connection_config = mcp_config.connection_config or {}
client = SimpleMCPClient(mcp_config.server_url, connection_config)
client = MCPClient(mcp_config.server_url, connection_config)
if await client.connect():
try:
# 获取工具列表
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
# 更新数据库
mcp_config.available_tools = tool_names
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "healthy"
mcp_config.error_message = None
# 更新工具状态
config.status = ToolStatus.AVAILABLE.value
self.db.commit()
await client.disconnect()
return {
"success": True,
"message": "工具列表同步成功",
"tools_count": len(tool_names),
"tools": tool_names
}
except Exception as e:
await client.disconnect()
# 更新错误状态
async with client:
# 获取工具列表
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
# 更新数据库
mcp_config.available_tools = tool_names
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "healthy"
mcp_config.error_message = None
# 更新工具状态
config.status = ToolStatus.AVAILABLE.value
self.db.commit()
return {
"success": True,
"message": "工具列表同步成功",
"tools_count": len(tool_names),
"tools": tool_names
}
except Exception as e:
# 更新错误状态
try:
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if mcp_config:
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
config.status = ToolStatus.ERROR.value
self.db.commit()
return {"success": False, "message": f"获取工具列表失败: {str(e)}"}
else:
# 连接失败
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = "连接失败"
config.status = ToolStatus.ERROR.value
self.db.commit()
return {"success": False, "message": "MCP连接失败"}
except Exception as e:
except:
pass
logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}")
return {"success": False, "message": f"同步失败: {str(e)}"}