Merge pull request #41 from SuanmoSuanyangTechnology/feature/agent-tool_xjn

Feature/agent tool xjn
This commit is contained in:
Mark
2026-01-06 20:30:55 +08:00
committed by GitHub
15 changed files with 413 additions and 106 deletions

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.response_utils import success from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user from app.dependencies import get_current_user
@@ -26,4 +27,9 @@ def get_workspace_list(
): ):
"""获取工作空间列表""" """获取工作空间列表"""
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id) workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
return success(data=workspace_list, msg="工作空间列表获取成功") return success(data=workspace_list, msg="工作空间列表获取成功")
@router.get("/version", response_model=ApiResponse)
def get_system_version():
"""获取系统版本号"""
return success(data={"version": settings.SYSTEM_VERSION}, msg="系统版本获取成功")

View File

@@ -433,7 +433,8 @@ async def chat(
config=agent_config, config=agent_config,
memory=payload.memory, memory=payload.memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
): ):
yield event yield event
@@ -469,7 +470,8 @@ async def chat(
web_search=payload.web_search, web_search=payload.web_search,
memory=payload.memory, memory=payload.memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
) )
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT: elif app_type == AppType.MULTI_AGENT:
@@ -486,8 +488,8 @@ async def chat(
config=config, config=config,
web_search=payload.web_search, web_search=payload.web_search,
memory=payload.memory, memory=payload.memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
): ):
yield event yield event

View File

@@ -154,7 +154,8 @@ async def chat(
config=agent_config, config=agent_config,
memory=memory, memory=memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
): ):
yield event yield event
@@ -178,7 +179,8 @@ async def chat(
web_search=web_search, web_search=web_search,
memory=memory, memory=memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
) )
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT: elif app_type == AppType.MULTI_AGENT:
@@ -195,8 +197,8 @@ async def chat(
config=config, config=config,
web_search=web_search, web_search=web_search,
memory=memory, memory=memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
): ):
yield event yield event
@@ -212,7 +214,6 @@ async def chat(
# 多 Agent 非流式返回 # 多 Agent 非流式返回
result = await app_chat_service.multi_agent_chat( result = await app_chat_service.multi_agent_chat(
message=payload.message, message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串 user_id=end_user_id, # 转换为字符串
@@ -291,4 +292,4 @@ async def chat(
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass

View File

@@ -7,7 +7,6 @@ LangChain Agent 封装
- 支持流式输出 - 支持流式输出
- 使用 RedBearLLM 支持多提供商 - 使用 RedBearLLM 支持多提供商
""" """
import os
import time import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
@@ -97,8 +96,7 @@ class LangChainAgent:
"temperature": temperature, "temperature": temperature,
"streaming": streaming, "streaming": streaming,
"tool_count": len(self.tools), "tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [], "tool_names": [tool.name for tool in self.tools] if self.tools else []
"tool_count": len(self.tools)
} }
) )
@@ -139,8 +137,11 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content)) messages.append(HumanMessage(content=user_content))
return messages return messages
async def term_memory_save(self,messages,end_user_end,aimessages): async def term_memory_save(self,messages,end_user_end,aimessages):
'''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j''' """
短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j
"""
end_user_end=f"Term_{end_user_end}" end_user_end=f"Term_{end_user_end}"
print(messages) print(messages)
print(aimessages) print(aimessages)
@@ -154,6 +155,7 @@ class LangChainAgent:
store.delete_duplicate_sessions() store.delete_duplicate_sessions()
# logger.info(f'Redis_Agent:{end_user_end};{session_id}') # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
return session_id return session_id
async def term_memory_redis_read(self,end_user_end): async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}" end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)

View File

@@ -164,6 +164,9 @@ class Settings:
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60")) TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10")) TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true" ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v1.0.0")
def get_memory_output_path(self, filename: str = "") -> str: def get_memory_output_path(self, filename: str = "") -> str:
""" """

View File

@@ -191,10 +191,14 @@ class BaseTool(ABC):
execution_time=execution_time execution_time=execution_time
) )
def to_langchain_tool(self): def to_langchain_tool(self, operation: Optional[str] = None):
"""转换为Langchain工具格式""" """转换为Langchain工具格式
Args:
operation: 特定操作(适用于有操作的工具)
"""
from app.core.tools.langchain_adapter import LangchainAdapter from app.core.tools.langchain_adapter import LangchainAdapter
return LangchainAdapter.convert_tool(self) return LangchainAdapter.convert_tool(self, operation)
def __repr__(self): def __repr__(self):
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>" return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"

View File

@@ -0,0 +1,216 @@
"""操作工具 - 为特定操作创建的工具包装器"""
from typing import List
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.models import ToolType
class OperationTool(BaseTool):
"""操作工具 - 包装基础工具的特定操作"""
def __init__(self, base_tool: BaseTool, operation: str):
self.base_tool = base_tool
self.operation = operation
super().__init__(base_tool.tool_id, base_tool.config)
@property
def name(self) -> str:
return f"{self.base_tool.name}_{self.operation}"
@property
def tool_type(self) -> ToolType:
"""工具类型"""
return ToolType.BUILTIN
@property
def description(self) -> str:
return f"{self.base_tool.description} - {self.operation}"
@property
def parameters(self) -> List[ToolParameter]:
"""返回特定操作的参数"""
if self.base_tool.name == 'datetime_tool':
return self._get_datetime_params()
elif self.base_tool.name == 'json_tool':
return self._get_json_params()
else:
# 默认返回除operation外的所有参数
return [p for p in self.base_tool.parameters if p.name != "operation"]
def _get_datetime_params(self) -> List[ToolParameter]:
"""获取datetime_tool特定操作的参数"""
if self.operation == "now":
return [
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
)
]
elif self.operation == "format":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
)
]
elif self.operation == "convert_timezone":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="from_timezone",
type=ParameterType.STRING,
description="源时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
elif self.operation == "timestamp_to_datetime":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
else:
return []
def _get_json_params(self) -> List[ToolParameter]:
"""获取json_tool特定操作的参数"""
base_params = [
ToolParameter(
name="input_data",
type=ParameterType.STRING,
description="输入数据JSON字符串、YAML字符串或XML字符串",
required=True
)
]
if self.operation == "insert":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
),
ToolParameter(
name="new_value",
type=ParameterType.STRING,
description="新值用于insert操作",
required=True
)
]
elif self.operation == "replace":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
),
ToolParameter(
name="old_text",
type=ParameterType.STRING,
description="要替换的原文本用于replace操作",
required=True
),
ToolParameter(
name="new_text",
type=ParameterType.STRING,
description="替换后的新文本用于replace操作",
required=True
)
]
elif self.operation == "delete":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
)
]
elif self.operation == "parse":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
)
]
else:
return base_params
async def execute(self, **kwargs) -> ToolResult:
"""执行特定操作"""
# 添加operation参数
kwargs["operation"] = self.operation
return await self.base_tool.execute(**kwargs)

View File

@@ -1,4 +1,5 @@
"""自定义工具基类""" """自定义工具基类"""
import json
import time import time
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
import aiohttp import aiohttp
@@ -135,6 +136,13 @@ class CustomTool(BaseTool):
if not self.schema_content: if not self.schema_content:
return operations return operations
if isinstance(self.schema_content, str):
try:
self.schema_content = json.loads(self.schema_content)
except json.JSONDecodeError:
logger.error(f"无效的OpenAPI schema: {self.schema_content}")
return operations
paths = self.schema_content.get("paths", {}) paths = self.schema_content.get("paths", {})

View File

@@ -38,7 +38,7 @@ class LangchainToolWrapper(LangchainBaseTool):
name=tool_instance.name, name=tool_instance.name,
description=tool_instance.description, description=tool_instance.description,
args_schema=args_schema, args_schema=args_schema,
_tool_instance=tool_instance, tool_instance=tool_instance,
**kwargs **kwargs
) )
@@ -59,7 +59,7 @@ class LangchainToolWrapper(LangchainBaseTool):
"""异步执行工具""" """异步执行工具"""
try: try:
# 执行内部工具 # 执行内部工具
result = await self._tool_instance.safe_execute(**kwargs) result = await self.tool_instance.safe_execute(**kwargs)
# 转换结果为Langchain格式 # 转换结果为Langchain格式
return LangchainAdapter._format_result_for_langchain(result) return LangchainAdapter._format_result_for_langchain(result)
@@ -73,24 +73,39 @@ class LangchainAdapter:
"""Langchain适配器 - 负责工具格式转换和标准化""" """Langchain适配器 - 负责工具格式转换和标准化"""
@staticmethod @staticmethod
def convert_tool(tool: BaseTool) -> LangchainToolWrapper: def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper:
"""将内部工具转换为Langchain工具 """将内部工具转换为Langchain工具
Args: Args:
tool: 内部工具实例 tool: 内部工具实例
operation: 特定操作(适用于有操作的工具)
Returns: Returns:
Langchain兼容的工具包装器 Langchain兼容的工具包装器
""" """
try: try:
wrapper = LangchainToolWrapper(tool_instance=tool) if operation and tool.name in ['datetime_tool', 'json_tool']:
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式") # 为特定操作创建工具
return wrapper operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
return wrapper
else:
# 单个工具
wrapper = LangchainToolWrapper(tool_instance=tool)
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
return wrapper
except Exception as e: except Exception as e:
logger.error(f"工具转换失败: {tool.name}, 错误: {e}") logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
raise raise
@staticmethod
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
"""为特定操作创建工具实例"""
from app.core.tools.builtin.operation_tool import OperationTool
return OperationTool(base_tool, operation)
@staticmethod @staticmethod
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]: def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
"""批量转换工具 """批量转换工具
@@ -110,7 +125,7 @@ class LangchainAdapter:
except Exception as e: except Exception as e:
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}") logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具") logger.info(f"批量转换完成: {len(converted_tools)} 个工具")
return converted_tools return converted_tools
@staticmethod @staticmethod
@@ -169,9 +184,10 @@ class LangchainAdapter:
"ToolArgsSchema", "ToolArgsSchema",
(BaseModel,), (BaseModel,),
{ {
"__module__": __name__,
"__annotations__": annotations, "__annotations__": annotations,
**fields, "model_config": {"extra": "forbid"},
"Config": type("Config", (), {"extra": "forbid"}) **fields
} }
) )

View File

@@ -16,14 +16,17 @@ logger = get_business_logger()
class MCPServiceManager: class MCPServiceManager:
"""MCP服务管理器 - 管理MCP服务的生命周期""" """MCP服务管理器 - 管理MCP服务的生命周期"""
def __init__(self, db: Session): def __init__(self, db: Session = None):
"""初始化MCP服务管理器 """初始化MCP服务管理器
Args: Args:
db: 数据库会话 db: 数据库会话(可选)
""" """
self.db = db self.db = db
self.connection_pool = MCPConnectionPool(max_connections=20) if db:
self.connection_pool = MCPConnectionPool(max_connections=20)
else:
self.connection_pool = None
# 服务状态管理 # 服务状态管理
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
@@ -592,7 +595,7 @@ class MCPServiceManager:
except Exception as e: except Exception as e:
logger.error(f"清理失效服务失败: {e}") logger.error(f"清理失效服务失败: {e}")
def get_manager_status(self) -> Dict[str, Any]: def get_manager_status(self) -> Dict[str, Any]:
"""获取管理器状态""" """获取管理器状态"""
return { return {

View File

@@ -1,6 +1,6 @@
import datetime import datetime
import uuid import uuid
from typing import Optional, Any, List, Dict from typing import Optional, Any, List, Dict, Union
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
@@ -36,6 +36,12 @@ class KnowledgeRetrievalConfig(BaseModel):
class ToolConfig(BaseModel): class ToolConfig(BaseModel):
"""工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具")
tool_id: str = Field(default=None, description="工具ID")
operation: Optional[str] = Field(default_factory=dict, description="工具特定配置")
class ToolOldConfig(BaseModel):
"""工具配置""" """工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具") enabled: bool = Field(default=False, description="是否启用该工具")
config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置") config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置")
@@ -103,9 +109,9 @@ class AgentConfigCreate(BaseModel):
) )
# 工具配置 # 工具配置
tools: Dict[str, ToolConfig] = Field( tools: List[ToolConfig] = Field(
default_factory=dict, default_factory=list,
description="工具配置key 为工具名称web_search, code_interpreter, image_generation 等)" description="Agent 可用的工具列表"
) )
@@ -158,7 +164,7 @@ class AgentConfigUpdate(BaseModel):
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表") variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
# 工具配置 # 工具配置
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置") tools: Optional[List[ToolConfig]] = Field(default=None, description="工具列表")
# ---------- Output Schemas ---------- # ---------- Output Schemas ----------
@@ -216,7 +222,7 @@ class AgentConfig(BaseModel):
variables: List[VariableDefinition] = [] variables: List[VariableDefinition] = []
# 工具配置 # 工具配置
tools: Dict[str, ToolConfig] = {} tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
is_active: bool is_active: bool
created_at: datetime.datetime created_at: datetime.datetime

View File

@@ -2,14 +2,14 @@
Agent 配置格式转换器 Agent 配置格式转换器
用于将 Pydantic 模型转换为数据库存储格式 用于将 Pydantic 模型转换为数据库存储格式
""" """
from typing import Dict, Any, Optional from typing import Dict, Any, Optional, Union
from app.schemas.app_schema import ( from app.schemas.app_schema import (
KnowledgeRetrievalConfig, KnowledgeRetrievalConfig,
MemoryConfig, MemoryConfig,
VariableDefinition, VariableDefinition,
ToolConfig, ToolConfig,
AgentConfigCreate, AgentConfigCreate,
AgentConfigUpdate, AgentConfigUpdate, ToolOldConfig,
) )
@@ -47,10 +47,7 @@ class AgentConfigConverter:
# 5. 工具配置 # 5. 工具配置
if hasattr(config, 'tools') and config.tools: if hasattr(config, 'tools') and config.tools:
result["tools"] = { result["tools"] = [tool.model_dump() for tool in config.tools]
name: tool.model_dump()
for name, tool in config.tools.items()
}
return result return result
@@ -60,7 +57,7 @@ class AgentConfigConverter:
knowledge_retrieval: Optional[Dict[str, Any]], knowledge_retrieval: Optional[Dict[str, Any]],
memory: Optional[Dict[str, Any]], memory: Optional[Dict[str, Any]],
variables: Optional[list], variables: Optional[list],
tools: Optional[Dict[str, Any]], tools: Optional[Union[list, Dict[str, Any]]],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
将数据库存储格式转换为 Pydantic 对象 将数据库存储格式转换为 Pydantic 对象
@@ -113,9 +110,12 @@ class AgentConfigConverter:
# 5. 解析工具配置 # 5. 解析工具配置
if tools: if tools:
result["tools"] = { if isinstance(tools, list):
name: ToolConfig(**tool_data) result["tools"] = [ToolConfig(**tool_config) for tool_config in tools]
for name, tool_data in tools.items() else:
} result["tools"] = {
name: ToolOldConfig(**tool_data)
for name, tool_data in tools.items()
}
return result return result

View File

@@ -14,6 +14,10 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.db import get_db, get_db_context from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
@@ -43,6 +47,7 @@ class AppChatService:
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""聊天(非流式)""" """聊天(非流式)"""
@@ -68,6 +73,24 @@ class AppChatService:
# 准备工具列表 # 准备工具列表
tools = [] tools = []
# 获取工具服务
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools:
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具 # 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval: if knowledge_retrieval:
@@ -86,7 +109,7 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id) memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool) tools.append(memory_tool)
web_tools = config.tools # web_tools = config.tools
# web_search_choice = web_tools.get("web_search", {}) # web_search_choice = web_tools.get("web_search", {})
# web_search_enable = web_search_choice.get("enabled", False) # web_search_enable = web_search_choice.get("enabled", False)
# if web_search == True: # if web_search == True:
@@ -173,6 +196,7 @@ class AppChatService:
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""聊天(流式)""" """聊天(流式)"""
@@ -199,6 +223,23 @@ class AppChatService:
# 准备工具列表 # 准备工具列表
tools = [] tools = []
# 获取工具服务
tool_service = ToolService(self.db)
if hasattr(config, 'tools') and config.tools:
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具 # 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval: if knowledge_retrieval:
@@ -217,20 +258,20 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id) memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool) tools.append(memory_tool)
web_tools = config.tools # web_tools = config.tools
web_search_choice = web_tools.get("web_search", {}) # web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False) # web_search_enable = web_search_choice.get("enabled", False)
if web_search == True: # if web_search == True:
if web_search_enable == True: # if web_search_enable == True:
search_tool = create_web_search_tool({}) # search_tool = create_web_search_tool({})
tools.append(search_tool) # tools.append(search_tool)
#
logger.debug( # logger.debug(
"已添加网络搜索工具", # "已添加网络搜索工具",
extra={ # extra={
"tool_count": len(tools) # "tool_count": len(tools)
} # }
) # )
# 获取模型参数 # 获取模型参数
model_parameters = config.model_parameters model_parameters = config.model_parameters

View File

@@ -307,7 +307,7 @@ class AppService:
knowledge_retrieval=storage_data.get("knowledge_retrieval"), knowledge_retrieval=storage_data.get("knowledge_retrieval"),
memory=storage_data.get("memory"), memory=storage_data.get("memory"),
variables=storage_data.get("variables", []), variables=storage_data.get("variables", []),
tools=storage_data.get("tools", {}), tools=storage_data.get("tools", []),
is_active=True, is_active=True,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
@@ -689,7 +689,7 @@ class AppService:
knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None, knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None,
memory=source_config.memory.copy() if source_config.memory else None, memory=source_config.memory.copy() if source_config.memory else None,
variables=source_config.variables.copy() if source_config.variables else [], variables=source_config.variables.copy() if source_config.variables else [],
tools=source_config.tools.copy() if source_config.tools else {}, tools=source_config.tools.copy() if source_config.tools else [],
is_active=True, is_active=True,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
@@ -879,7 +879,7 @@ class AppService:
# if data.variables is not None: # if data.variables is not None:
agent_cfg.variables = storage_data.get("variables", []) agent_cfg.variables = storage_data.get("variables", [])
# if data.tools is not None: # if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", {}) agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.updated_at = now agent_cfg.updated_at = now
@@ -966,7 +966,7 @@ class AppService:
"max_history": 10 "max_history": 10
}, },
variables=[], variables=[],
tools={}, tools=[],
is_active=True, is_active=True,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
@@ -1183,7 +1183,7 @@ class AppService:
"knowledge_retrieval": agent_cfg.knowledge_retrieval, "knowledge_retrieval": agent_cfg.knowledge_retrieval,
"memory": agent_cfg.memory, "memory": agent_cfg.memory,
"variables": agent_cfg.variables or [], "variables": agent_cfg.variables or [],
"tools": agent_cfg.tools or {}, "tools": agent_cfg.tools or [],
} }
# config = AgentConfigConverter.from_storage_format(agent_cfg) # config = AgentConfigConverter.from_storage_format(agent_cfg)
default_model_config_id = agent_cfg.default_model_config_id default_model_config_id = agent_cfg.default_model_config_id

View File

@@ -10,19 +10,22 @@ import time
import uuid import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services.langchain_tool_server import Search from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_parameter_merger import ModelParameterMerger
from langchain.tools import tool from app.services.tool_service import ToolService
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger() logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel): class KnowledgeRetrievalInput(BaseModel):
@@ -291,24 +294,22 @@ class DraftRunService:
# 4. 准备工具列表 # 4. 准备工具列表
tools = [] tools = []
# 添加网络搜索工具 tool_service = ToolService(self.db)
if web_search:
if agent_config.tools:
web_search_config = agent_config.tools.get("web_search", {})
web_search_enable = web_search_config.get("enabled", False)
if web_search_enable: # 从配置中获取启用的工具
logger.info("网络搜索已启用") if hasattr(agent_config, 'tools') and agent_config.tools:
# 创建网络搜索工具 for tool_config in agent_config.tools:
search_tool = create_web_search_tool(web_search_config) if tool_config.get("enabled", False):
tools.append(search_tool) # 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
logger.debug( ToolRepository.get_tenant_id_by_workspace_id(
"已添加网络搜索工具", self.db, str(workspace_id)))
extra={ if tool_instance:
"tool_count": len(tools) if tool_instance.name == "baidu_search_tool" and not web_search:
} continue
) # 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具 # 添加知识库检索工具
if agent_config.knowledge_retrieval: if agent_config.knowledge_retrieval:
@@ -503,24 +504,22 @@ class DraftRunService:
# 4. 准备工具列表 # 4. 准备工具列表
tools = [] tools = []
# 添加网络搜索工具 tool_service = ToolService(self.db)
if web_search:
if agent_config.tools:
web_search_config = agent_config.tools.get("web_search", {})
web_search_enable = web_search_config.get("enabled", False)
if web_search_enable: # 从配置中获取启用的工具
logger.info("网络搜索已启用") if hasattr(agent_config, 'tools') and agent_config.tools:
# 创建网络搜索工具 for tool_config in agent_config.tools:
search_tool = create_web_search_tool(web_search_config) if tool_config.get("enabled", False):
tools.append(search_tool) # 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
logger.debug( ToolRepository.get_tenant_id_by_workspace_id(
"已添加网络搜索工具", self.db, str(workspace_id)))
extra={ if tool_instance:
"tool_count": len(tools) if tool_instance.name == "baidu_search_tool" and not web_search:
} continue
) # 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具 # 添加知识库检索工具
if agent_config.knowledge_retrieval: if agent_config.knowledge_retrieval: