feat(apikey system): tool system development
This commit is contained in:
@@ -15,8 +15,13 @@ from langgraph.graph import StateGraph, START, END
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
TOOL_MANAGEMENT_AVAILABLE = True
|
||||
from app.db import get_db
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -434,3 +439,180 @@ async def execute_workflow_stream(
|
||||
)
|
||||
async for event in executor.execute_stream(input_data):
|
||||
yield event
|
||||
|
||||
|
||||
# ==================== 工具管理系统集成 ====================
|
||||
|
||||
def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
||||
"""获取工作流可用的工具列表
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
可用工具列表
|
||||
"""
|
||||
if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
logger.warning("工具管理系统不可用")
|
||||
return []
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import Session
|
||||
db = next(get_db())
|
||||
|
||||
# 创建工具注册表
|
||||
registry = ToolRegistry(db)
|
||||
|
||||
# 注册内置工具类
|
||||
from app.core.tools.builtin import (
|
||||
DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
|
||||
)
|
||||
registry.register_tool_class(DateTimeTool)
|
||||
registry.register_tool_class(JsonTool)
|
||||
registry.register_tool_class(BaiduSearchTool)
|
||||
registry.register_tool_class(MinerUTool)
|
||||
registry.register_tool_class(TextInTool)
|
||||
|
||||
# 获取活跃的工具
|
||||
import uuid
|
||||
tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
|
||||
active_tools = [tool for tool in tools if tool.status.value == "active"]
|
||||
|
||||
# 转换为Langchain工具
|
||||
langchain_tools = []
|
||||
for tool_info in active_tools:
|
||||
try:
|
||||
tool_instance = registry.get_tool(tool_info.id)
|
||||
if tool_instance:
|
||||
langchain_tool = LangchainAdapter.convert_tool(tool_instance)
|
||||
langchain_tools.append(langchain_tool)
|
||||
except Exception as e:
|
||||
logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
|
||||
|
||||
logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
|
||||
return langchain_tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流工具失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
class ToolWorkflowNode:
|
||||
"""工具工作流节点 - 在工作流中执行工具"""
|
||||
|
||||
def __init__(self, node_config: dict, workflow_config: dict):
|
||||
"""初始化工具节点
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
"""
|
||||
self.node_config = node_config
|
||||
self.workflow_config = workflow_config
|
||||
self.tool_id = node_config.get("tool_id")
|
||||
self.tool_parameters = node_config.get("parameters", {})
|
||||
|
||||
async def run(self, state: WorkflowState) -> WorkflowState:
|
||||
"""执行工具节点"""
|
||||
if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
logger.error("工具管理系统不可用")
|
||||
state["error"] = "工具管理系统不可用"
|
||||
return state
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import Session
|
||||
db = next(get_db())
|
||||
|
||||
# 创建工具执行器
|
||||
registry = ToolRegistry(db)
|
||||
executor = ToolExecutor(db, registry)
|
||||
|
||||
# 准备参数(支持变量替换)
|
||||
parameters = self._prepare_parameters(state)
|
||||
|
||||
# 执行工具
|
||||
result = await executor.execute_tool(
|
||||
tool_id=self.tool_id,
|
||||
parameters=parameters,
|
||||
user_id=uuid.UUID(state["user_id"]),
|
||||
workspace_id=uuid.UUID(state["workspace_id"])
|
||||
)
|
||||
|
||||
# 更新状态
|
||||
node_id = self.node_config.get("id")
|
||||
if result.success:
|
||||
state["node_outputs"][node_id] = {
|
||||
"type": "tool",
|
||||
"tool_id": self.tool_id,
|
||||
"output": result.data,
|
||||
"execution_time": result.execution_time,
|
||||
"token_usage": result.token_usage
|
||||
}
|
||||
|
||||
# 更新运行时变量
|
||||
if isinstance(result.data, dict):
|
||||
for key, value in result.data.items():
|
||||
state["runtime_vars"][f"{node_id}.{key}"] = value
|
||||
else:
|
||||
state["runtime_vars"][f"{node_id}.result"] = result.data
|
||||
else:
|
||||
state["error"] = result.error
|
||||
state["error_node"] = node_id
|
||||
state["node_outputs"][node_id] = {
|
||||
"type": "tool",
|
||||
"tool_id": self.tool_id,
|
||||
"error": result.error,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具节点执行失败: {e}")
|
||||
state["error"] = str(e)
|
||||
state["error_node"] = self.node_config.get("id")
|
||||
return state
|
||||
|
||||
def _prepare_parameters(self, state: WorkflowState) -> dict:
|
||||
"""准备工具参数(支持变量替换)"""
|
||||
parameters = {}
|
||||
|
||||
for key, value in self.tool_parameters.items():
|
||||
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||
# 变量替换
|
||||
var_path = value[2:-1]
|
||||
|
||||
# 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
|
||||
if "." in var_path:
|
||||
parts = var_path.split(".")
|
||||
current = state.get("variables", {})
|
||||
|
||||
for part in parts:
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
# 尝试从运行时变量获取
|
||||
runtime_key = ".".join(parts)
|
||||
current = state.get("runtime_vars", {}).get(runtime_key, value)
|
||||
break
|
||||
|
||||
parameters[key] = current
|
||||
else:
|
||||
# 简单变量
|
||||
variables = state.get("variables", {})
|
||||
parameters[key] = variables.get(var_path, value)
|
||||
else:
|
||||
parameters[key] = value
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
# 注册工具节点到NodeFactory(如果存在)
|
||||
try:
|
||||
from app.core.workflow.nodes import NodeFactory
|
||||
if hasattr(NodeFactory, 'register_node_type'):
|
||||
NodeFactory.register_node_type("tool", ToolWorkflowNode)
|
||||
logger.info("工具节点已注册到工作流系统")
|
||||
except Exception as e:
|
||||
logger.warning(f"注册工具节点失败: {e}")
|
||||
Reference in New Issue
Block a user