Merge pull request #466 from SuanmoSuanyangTechnology/feature/agent-variables

Enhance workflow input handling and add legacy dify compatibility
This commit is contained in:
Mark
2026-03-05 14:21:31 +08:00
committed by GitHub
23 changed files with 418 additions and 922 deletions

View File

@@ -396,10 +396,10 @@ async def draft_run(
from app.models import AgentConfig, ModelConfig from app.models import AgentConfig, ModelConfig
from sqlalchemy import select from sqlalchemy import select
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.services.draft_run_service import DraftRunService from app.services.draft_run_service import AgentRunService
service = AppService(db) service = AppService(db)
draft_service = DraftRunService(db) draft_service = AgentRunService(db)
# 1. 验证应用 # 1. 验证应用
app = service._get_app_or_404(app_id) app = service._get_app_or_404(app_id)
@@ -484,8 +484,8 @@ async def draft_run(
} }
) )
from app.services.draft_run_service import DraftRunService from app.services.draft_run_service import AgentRunService
draft_service = DraftRunService(db) draft_service = AgentRunService(db)
result = await draft_service.run( result = await draft_service.run(
agent_config=agent_cfg, agent_config=agent_cfg,
model_config=model_config, model_config=model_config,
@@ -789,8 +789,8 @@ async def draft_run_compare(
# 流式返回 # 流式返回
if payload.stream: if payload.stream:
async def event_generator(): async def event_generator():
from app.services.draft_run_service import DraftRunService from app.services.draft_run_service import AgentRunService
draft_service = DraftRunService(db) draft_service = AgentRunService(db)
async for event in draft_service.run_compare_stream( async for event in draft_service.run_compare_stream(
agent_config=agent_cfg, agent_config=agent_cfg,
models=model_configs, models=model_configs,
@@ -820,8 +820,8 @@ async def draft_run_compare(
) )
# 非流式返回 # 非流式返回
from app.services.draft_run_service import DraftRunService from app.services.draft_run_service import AgentRunService
draft_service = DraftRunService(db) draft_service = AgentRunService(db)
result = await draft_service.run_compare( result = await draft_service.run_compare(
agent_config=agent_cfg, agent_config=agent_cfg,
models=model_configs, models=model_configs,

View File

@@ -21,6 +21,7 @@ from pydantic import BaseModel, Field
T = TypeVar("T") T = TypeVar("T")
class RedBearModelConfig(BaseModel): class RedBearModelConfig(BaseModel):
"""模型配置基类""" """模型配置基类"""
model_name: str model_name: str
@@ -35,6 +36,7 @@ class RedBearModelConfig(BaseModel):
concurrency: int = 5 # 并发限流 concurrency: int = 5 # 并发限流
extra_params: Dict[str, Any] = {} extra_params: Dict[str, Any] = {}
class RedBearModelFactory: class RedBearModelFactory:
"""模型工厂类""" """模型工厂类"""
@@ -154,7 +156,8 @@ class RedBearModelFactory:
else: else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:
def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]:
"""根据模型提供商获取对应的模型类""" """根据模型提供商获取对应的模型类"""
provider = config.provider.lower() provider = config.provider.lower()
@@ -183,10 +186,11 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.
else: else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_embedding_class(provider: str) -> type[Embeddings]: def get_provider_embedding_class(provider: str) -> type[Embeddings]:
"""根据模型提供商获取对应的模型类""" """根据模型提供商获取对应的模型类"""
provider = provider.lower() provider = provider.lower()
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings return OpenAIEmbeddings
elif provider == ModelProvider.DASHSCOPE: elif provider == ModelProvider.DASHSCOPE:
@@ -201,10 +205,11 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]:
else: else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_rerank_class(provider: str): def get_provider_rerank_class(provider: str):
"""根据模型提供商获取对应的模型类""" """根据模型提供商获取对应的模型类"""
provider = provider.lower() provider = provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
from langchain_community.document_compressors import JinaRerank from langchain_community.document_compressors import JinaRerank
return JinaRerank return JinaRerank
# elif provider == ModelProvider.OLLAMA: # elif provider == ModelProvider.OLLAMA:

View File

@@ -98,7 +98,7 @@ class DifyConverter(BaseConverter):
if not var_selector: if not var_selector:
return "" return ""
selector = var_selector.split('.') selector = var_selector.split('.')
if len(selector) not in [2, 3]: if len(selector) not in [2, 3] and var_selector != "context":
raise Exception(f"invalid variable selector: {var_selector}") raise Exception(f"invalid variable selector: {var_selector}")
if len(selector) == 3: if len(selector) == 3:
selector = selector[1:] selector = selector[1:]
@@ -332,7 +332,9 @@ class DifyConverter(BaseConverter):
messages.append( messages.append(
MessageConfig( MessageConfig(
role="user", role="user",
content=self.trans_variable_format(node_data["memory"]["query_prompt_template"]) content=self.trans_variable_format(
node_data["memory"].get("query_prompt_template", "{{#sys.query#}}")
)
) )
) )
vision = node_data["vision"]["enabled"] vision = node_data["vision"]["enabled"]

View File

@@ -80,7 +80,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
return True return True
def validate_config(self) -> bool: def validate_config(self) -> bool:
require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'}) require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
if not all(field in self.config for field in require_fields): if not all(field in self.config for field in require_fields):
return False return False

View File

@@ -303,30 +303,44 @@ class VariablePool:
""" """
return self._get_variable_struct(selector) is not None return self._get_variable_struct(selector) is not None
def get_all_system_vars(self) -> dict[str, Any]: def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量 """获取所有系统变量
Returns: Returns:
系统变量字典 系统变量字典
""" """
sys_namespace = self.variables.get("sys", {}) sys_namespace = self.variables.get("sys", {})
if literal:
return {k: v.instance.to_literal() for k, v in sys_namespace.items()}
return {k: v.instance.get_value() for k, v in sys_namespace.items()} return {k: v.instance.get_value() for k, v in sys_namespace.items()}
def get_all_conversation_vars(self) -> dict[str, Any]: def get_all_conversation_vars(self, literal=False) -> dict[str, Any]:
"""获取所有会话变量 """获取所有会话变量
Returns: Returns:
会话变量字典 会话变量字典
""" """
conv_namespace = self.variables.get("conv", {}) conv_namespace = self.variables.get("conv", {})
if literal:
return {k: v.instance.to_literal() for k, v in conv_namespace.items()}
return {k: v.instance.get_value() for k, v in conv_namespace.items()} return {k: v.instance.get_value() for k, v in conv_namespace.items()}
def get_all_node_outputs(self) -> dict[str, Any]: def get_all_node_outputs(self, literal=False) -> dict[str, Any]:
"""获取所有节点输出(运行时变量) """获取所有节点输出(运行时变量)
Returns: Returns:
节点输出字典,键为节点 ID 节点输出字典,键为节点 ID
""" """
if literal:
runtime_vars = {
namespace: {
k: v.instance.to_literal()
for k, v in vars_dict.items()
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
else:
runtime_vars = { runtime_vars = {
namespace: { namespace: {
k: v.instance.get_value() k: v.instance.get_value()

View File

@@ -16,7 +16,7 @@ from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db from app.db import get_db
from app.models import AppRelease from app.models import AppRelease
from app.services.draft_run_service import DraftRunService from app.services.draft_run_service import AgentRunService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING}
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]: def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]:
"""准备 Agent公共逻辑 """准备 Agent公共逻辑
Args: Args:
@@ -65,7 +65,7 @@ class AgentNode(BaseNode):
if not release: if not release:
raise ValueError(f"Agent 不存在: {agent_id}") raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = DraftRunService(db) draft_service = AgentRunService(db)
return draft_service, release, message return draft_service, release, message

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import cached_property from functools import cached_property
from typing import Any, AsyncGenerator from typing import Any, AsyncGenerator
@@ -10,8 +11,10 @@ from app.core.config import settings
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.enums import BRANCH_NODES from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.services.multimodal_service import PROVIDER_STRATEGIES from app.db import get_db_read
from app.schemas import FileInput
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -548,9 +551,9 @@ class BaseNode(ABC):
return render_template( return render_template(
template=template, template=template,
conv_vars=variable_pool.get_all_conversation_vars(), conv_vars=variable_pool.get_all_conversation_vars(literal=True),
node_outputs=variable_pool.get_all_node_outputs(), node_outputs=variable_pool.get_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(), system_vars=variable_pool.get_all_system_vars(literal=True),
strict=strict strict=strict
) )
@@ -614,16 +617,32 @@ class BaseNode(ABC):
return variable_pool.has(selector) return variable_pool.has(selector)
@staticmethod @staticmethod
async def process_message(provider, content, enable_file=False) -> dict | str | None: async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None:
if isinstance(content, str): if isinstance(content, str):
if enable_file: if enable_file:
return {"text": content} return {"text": content}
return content return content
elif isinstance(content, dict):
trans_tool = PROVIDER_STRATEGIES[provider]() elif isinstance(content, FileObject):
result = await trans_tool.format_image(content["url"]) if content.content_cache.get(provider):
return result return content.content_cache[provider]
raise TypeError('Unexpect input value type') with get_db_read() as db:
multimodel_service = MultimodalService(db, provider)
message = await multimodel_service.process_files(
[FileInput.model_construct(
type=content.type,
url=content.url,
transfer_method=content.transfer_method,
file_type=content.origin_file_type,
upload_file_id=content.file_id
)]
)
if message:
content.content_cache[provider] = message[0]
return message[0]
return None
raise TypeError(f'Unexpect input value type - {type(content)}')
@staticmethod @staticmethod
def process_model_output(content) -> str: def process_model_output(content) -> str:

View File

@@ -91,8 +91,8 @@ class IterationRuntime:
return loopstate return loopstate
def merge_conv_vars(self): def merge_conv_vars(self):
self.variable_pool.get_all_conversation_vars().update( self.variable_pool.variables["conv"].update(
self.child_variable_pool.get_all_conversation_vars() self.child_variable_pool.variables["conv"]
) )
async def run_task(self, item, idx): async def run_task(self, item, idx):

View File

@@ -156,7 +156,7 @@ class LoopRuntime:
def merge_conv_vars(self, loopstate): def merge_conv_vars(self, loopstate):
self.variable_pool.variables["conv"].update( self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables.get("conv", {}) self.child_variable_pool.variables["conv"]
) )
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False) loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars loopstate["node_outputs"][self.node_id] = loop_vars

View File

@@ -172,9 +172,9 @@ class LLMNode(BaseNode):
if self.typed_config.vision_input and self.typed_config.vision: if self.typed_config.vision_input and self.typed_config.vision:
file_content = [] file_content = []
files = variable_pool.get_value(self.typed_config.vision_input) files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files: for file in files.value:
content = await self.process_message(provider, file, self.typed_config.vision) content = await self.process_message(provider, file.value, self.typed_config.vision)
if content: if content:
file_content.append(content) file_content.append(content)
if messages and messages[-1]["role"] == 'user': if messages and messages[-1]["role"] == 'user':

View File

@@ -2,7 +2,7 @@ from enum import StrEnum
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from typing import Any from typing import Any
from pydantic import BaseModel from pydantic import BaseModel, Field
from app.schemas import FileType from app.schemas import FileType
@@ -45,7 +45,7 @@ class VariableType(StrEnum):
return cls.NUMBER return cls.NUMBER
elif isinstance(var, bool): elif isinstance(var, bool):
return cls.BOOLEAN return cls.BOOLEAN
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')): elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
return cls.FILE return cls.FILE
elif isinstance(var, dict): elif isinstance(var, dict):
return cls.OBJECT return cls.OBJECT
@@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
class FileObject(BaseModel): class FileObject(BaseModel):
type: FileType type: FileType
url: str url: str
__file: bool transfer_method: str
origin_file_type: str
file_id: str | None
content_cache: dict = Field(default_factory=dict)
is_file: bool
class BaseVariable(ABC): class BaseVariable(ABC):

View File

@@ -63,13 +63,16 @@ class FileVariable(BaseVariable):
def valid_value(self, value) -> FileObject: def valid_value(self, value) -> FileObject:
if isinstance(value, dict): if isinstance(value, dict):
if not value.get("__file"): if not value.get("is_file"):
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
return FileObject( return FileObject(
**{ **{
"type": str(value.get('type')), "type": str(value.get('type')),
"transfer_method": value.get("transfer_method"),
"url": value.get('url'), "url": value.get('url'),
"__file": True "file_id": value.get("file_id"),
"origin_file_type": value.get("origin_file_type"),
"is_file": True
} }
) )
if isinstance(value, FileObject): if isinstance(value, FileObject):

View File

@@ -155,8 +155,7 @@ class ApiKey(BaseModel):
return datetime.datetime.now() > self.expires_at return datetime.datetime.now() > self.expires_at
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at') @field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
@classmethod def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳""" """将datetime转换为时间戳"""
return datetime_to_timestamp(v) return datetime_to_timestamp(v)
@@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel):
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)") avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
@field_serializer('last_used_at') @field_serializer('last_used_at')
@classmethod def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳""" """将datetime转换为时间戳"""
return datetime_to_timestamp(v) return datetime_to_timestamp(v)
@@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel):
created_at: datetime.datetime created_at: datetime.datetime
@field_serializer('created_at') @field_serializer('created_at')
@classmethod def serialize_datetime(self, v: datetime.datetime) -> int:
def serialize_datetime(cls, v: datetime.datetime) -> int:
"""将datetime转换为时间戳""" """将datetime转换为时间戳"""
return datetime_to_timestamp(v) return datetime_to_timestamp(v)

View File

@@ -64,14 +64,14 @@ class ExecutionConfig(BaseModel):
class MultiAgentConfigCreate(BaseModel): class MultiAgentConfigCreate(BaseModel):
"""创建多 Agent 配置""" """创建多 Agent 配置"""
master_agent_id: uuid.UUID = Field(..., description="主 Agent ID") master_agent_id: uuid.UUID = Field(..., description="主 Agent ID")
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称")
orchestration_mode: str = Field( orchestration_mode: str = Field(
default="collaboration", default="collaboration",
pattern="^(collaboration|supervisor)$", pattern="^(collaboration|supervisor)$",
description="协作模式collaboration协作| supervisor监督" description="协作模式collaboration协作| supervisor监督"
) )
sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表") sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表")
routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则") routing_rules: Optional[List[RoutingRule]] = Field(default=None, description="路由规则")
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置") execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
aggregation_strategy: str = Field( aggregation_strategy: str = Field(
default="merge", default="merge",
@@ -83,7 +83,7 @@ class MultiAgentConfigCreate(BaseModel):
class MultiAgentConfigUpdate(BaseModel): class MultiAgentConfigUpdate(BaseModel):
"""更新多 Agent 配置""" """更新多 Agent 配置"""
master_agent_id: Optional[uuid.UUID] = None master_agent_id: Optional[uuid.UUID] = None
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称")
default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID") default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID")
model_parameters: Optional[ModelParameters] = Field( model_parameters: Optional[ModelParameters] = Field(
None, None,

View File

@@ -263,8 +263,8 @@ def create_agent_invocation_tool(
try: try:
# 9. 调用 Agent # 9. 调用 Agent
from app.services.draft_run_service import DraftRunService from app.services.draft_run_service import AgentRunService
draft_service = DraftRunService(db) draft_service = AgentRunService(db)
result = await draft_service.run( result = await draft_service.run(
agent_config=agent_config, agent_config=agent_config,

View File

@@ -10,25 +10,24 @@ from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository
from app.db import get_db from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig from app.models import MultiAgentConfig, AgentConfig
from app.models import WorkflowConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
AgentRunService
from app.services.draft_run_service import create_web_search_tool from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.workflow_service import WorkflowService
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
from app.services.workflow_service import WorkflowService
logger = get_business_logger() logger = get_business_logger()
@@ -39,6 +38,8 @@ class AppChatService:
def __init__(self, db: Session): def __init__(self, db: Session):
self.db = db self.db = db
self.conversation_service = ConversationService(db) self.conversation_service = ConversationService(db)
self.agent_service = AgentRunService(db)
self.workflow_service = WorkflowService(db)
async def agnet_chat( async def agnet_chat(
self, self,
@@ -55,12 +56,10 @@ class AppChatService:
files: Optional[List[FileInput]] = None # 新增:多模态文件 files: Optional[List[FileInput]] = None # 新增:多模态文件
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""聊天(非流式)""" """聊天(非流式)"""
start_time = time.time() start_time = time.time()
config_id = None config_id = None
if variables is None: variables = self.agent_service.prepare_variables(variables, config.variables)
variables = {}
# 获取模型配置ID # 获取模型配置ID
model_config_id = config.default_model_config_id model_config_id = config.default_model_config_id
@@ -79,74 +78,20 @@ class AppChatService:
tools = [] tools = []
# 获取工具服务 # 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具 tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools) tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具") if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
# 应用动态过滤 tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
# 添加长期记忆工具
memory_flag = False memory_flag = False
if memory == True: if memory:
memory_config = config.memory memory_tools, memory_flag = self.agent_service.load_memory_config(
if memory_config.get("enabled") and user_id: config.memory, user_id, storage_type, user_rag_memory_id
memory_flag = True )
memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.extend(memory_tools)
tools.append(memory_tool)
# 获取模型参数 # 获取模型参数
model_parameters = config.model_parameters model_parameters = config.model_parameters
@@ -246,10 +191,9 @@ class AppChatService:
try: try:
start_time = time.time() start_time = time.time()
config_id = None config_id = None
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
if variables is None: variables = self.agent_service.prepare_variables(variables, config.variables)
variables = {}
# 获取模型配置ID # 获取模型配置ID
model_config_id = config.default_model_config_id model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
@@ -267,73 +211,22 @@ class AppChatService:
tools = [] tools = []
# 获取工具服务 # 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug( skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools) tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具") if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
# 应用动态过滤 tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
if memory: if memory:
memory_config = config.memory memory_tools, memory_flag = self.agent_service.load_memory_config(
if memory_config.get("enabled") and user_id: config.memory, user_id, storage_type, user_rag_memory_id
memory_flag = True )
memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.extend(memory_tools)
tools.append(memory_tool)
# 获取模型参数 # 获取模型参数
model_parameters = config.model_parameters model_parameters = config.model_parameters
@@ -372,9 +265,6 @@ class AppChatService:
processed_files = await multimodal_service.process_files(files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件") logger.info(f"处理了 {len(processed_files)} 个文件")
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
# 流式调用 Agent支持多模态 # 流式调用 Agent支持多模态
full_content = "" full_content = ""
total_tokens = 0 total_tokens = 0
@@ -418,7 +308,7 @@ class AppChatService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# 发送结束事件 # 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info( logger.info(
@@ -437,7 +327,7 @@ class AppChatService:
except Exception as e: except Exception as e:
logger.error(f"流式聊天失败: {str(e)}", exc_info=True) logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
# 发送错误事件 # 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" yield f"event: end\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
async def multi_agent_chat( async def multi_agent_chat(
self, self,
@@ -524,8 +414,6 @@ class AppChatService:
"""多 Agent 聊天(流式)""" """多 Agent 聊天(流式)"""
start_time = time.time() start_time = time.time()
actual_config_id = None
config_id = actual_config_id
if variables is None: if variables is None:
variables = {} variables = {}
@@ -631,7 +519,6 @@ class AppChatService:
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""聊天(非流式)""" """聊天(非流式)"""
workflow_service = WorkflowService(self.db)
payload = DraftRunRequest( payload = DraftRunRequest(
message=message, message=message,
variables=variables, variables=variables,
@@ -639,7 +526,7 @@ class AppChatService:
stream=True, stream=True,
user_id=user_id user_id=user_id
) )
return await workflow_service.run( return await self.workflow_service.run(
app_id=app_id, app_id=app_id,
payload=payload, payload=payload,
config=config, config=config,
@@ -666,7 +553,6 @@ class AppChatService:
) -> AsyncGenerator[dict, None]: ) -> AsyncGenerator[dict, None]:
"""聊天(流式)""" """聊天(流式)"""
workflow_service = WorkflowService(self.db)
payload = DraftRunRequest( payload = DraftRunRequest(
message=message, message=message,
variables=variables, variables=variables,
@@ -675,7 +561,7 @@ class AppChatService:
user_id=user_id, user_id=user_id,
files=files files=files
) )
async for event in workflow_service.run_stream( async for event in self.workflow_service.run_stream(
app_id=app_id, app_id=app_id,
payload=payload, payload=payload,
config=config, config=config,

View File

@@ -1791,372 +1791,6 @@ class AppService:
return shares return shares
# ==================== 试运行功能 ====================
async def draft_run(
self,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""试运行 Agent使用当前草稿配置
Args:
app_id: 应用ID
message: 用户消息
conversation_id: 会话ID用于多轮对话
user_id: 用户ID用于会话管理
variables: 自定义变量参数值
workspace_id: 工作空间ID用于权限验证
Returns:
Dict: 包含 AI 回复和元数据的字典
Raises:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用类型不支持或配置缺失时
"""
from app.services.draft_run_service import DraftRunService
logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
from app.models import ModelConfig
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config:
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 4. 调用试运行服务
logger.debug(
"准备调用试运行服务",
extra={
"app_id": str(app_id),
"model": model_config.name,
"has_conversation_id": bool(conversation_id),
"has_variables": bool(variables)
}
)
draft_service = DraftRunService(self.db)
result = await draft_service.run(
agent_config=agent_cfg,
model_config=model_config,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables
)
logger.debug(
"试运行服务返回结果",
extra={
"result_type": str(type(result)),
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict",
"has_message": "message" in result if isinstance(result, dict) else False,
"has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False
}
)
logger.info(
"试运行完成",
extra={
"app_id": str(app_id),
"elapsed_time": result.get("elapsed_time"),
"model": model_config.name
}
)
return result
async def draft_run_stream(
self,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
):
"""试运行 Agent流式返回
Args:
app_id: 应用ID
message: 用户消息
conversation_id: 会话ID用于多轮对话
user_id: 用户ID用于会话管理
variables: 自定义变量参数值
workspace_id: 工作空间ID用于权限验证
Yields:
str: SSE 格式的事件数据
Raises:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用类型不支持或配置缺失时
"""
from app.services.draft_run_service import DraftRunService
logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
from app.models import ModelConfig
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config:
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 4. 调用流式试运行服务
draft_service = DraftRunService(self.db)
async for event in draft_service.run_stream(
agent_config=agent_cfg,
model_config=model_config,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables
):
yield event
# ==================== 多模型对比试运行 ====================
async def draft_run_compare(
self,
*,
app_id: uuid.UUID,
message: str,
models: List[app_schema.ModelCompareItem],
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None,
parallel: bool = True,
timeout: int = 60
) -> Dict[str, Any]:
"""多模型对比试运行
Args:
app_id: 应用ID
message: 用户消息
models: 要对比的模型列表
conversation_id: 会话ID
user_id: 用户ID
variables: 变量参数
workspace_id: 工作空间ID
parallel: 是否并行执行
timeout: 超时时间(秒)
Returns:
Dict: 对比结果
"""
from app.models import ModelConfig
from app.services.draft_run_service import DraftRunService
logger.info(
"多模型对比试运行",
extra={
"app_id": str(app_id),
"model_count": len(models),
"parallel": parallel
}
)
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 准备所有模型配置
model_configs = []
for model_item in models:
model_config = self.db.get(ModelConfig, model_item.model_config_id)
if not model_config:
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
# 合并参数agent配置参数 + 请求覆盖参数
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
"model_config": model_config,
"parameters": merged_parameters,
"label": model_item.label or model_config.name,
"model_config_id": model_item.model_config_id
})
# 4. 调用 DraftRunService 的对比方法
draft_service = DraftRunService(self.db)
result = await draft_service.run_compare(
agent_config=agent_cfg,
models=model_configs,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
parallel=parallel,
timeout=timeout
)
logger.info(
"多模型对比完成",
extra={
"app_id": str(app_id),
"successful": result["successful_count"],
"failed": result["failed_count"]
}
)
return result
async def draft_run_compare_stream(
self,
*,
app_id: uuid.UUID,
message: str,
models: List[app_schema.ModelCompareItem],
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None,
parallel: bool = True,
timeout: int = 60
):
"""多模型对比试运行(流式返回)
Args:
app_id: 应用ID
message: 用户消息
models: 要对比的模型列表
conversation_id: 会话ID
user_id: 用户ID
variables: 变量参数
workspace_id: 工作空间ID
timeout: 超时时间(秒)
Yields:
str: SSE 格式的事件数据
"""
from app.models import ModelConfig
from app.services.draft_run_service import DraftRunService
logger.info(
"多模型对比流式试运行",
extra={
"app_id": str(app_id),
"model_count": len(models)
}
)
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 准备所有模型配置
model_configs = []
for model_item in models:
model_config = self.db.get(ModelConfig, model_item.model_config_id)
if not model_config:
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
# 合并参数agent配置参数 + 请求覆盖参数
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
"model_config": model_config,
"parameters": merged_parameters,
"label": model_item.label or model_config.name,
"model_config_id": model_item.model_config_id
})
# 4. 调用 DraftRunService 的流式对比方法
draft_service = DraftRunService(self.db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
parallel=parallel,
timeout=timeout
):
yield event
logger.info(
"多模型对比流式完成",
extra={"app_id": str(app_id)}
)
# ==================== 向后兼容的函数接口 ==================== # ==================== 向后兼容的函数接口 ====================
# 保留函数接口以兼容现有代码,但内部使用服务类 # 保留函数接口以兼容现有代码,但内部使用服务类
@@ -2278,53 +1912,6 @@ def get_apps_by_ids(
return service.get_apps_by_ids(app_ids, workspace_id) return service.get_apps_by_ids(app_ids, workspace_id)
# ==================== 向后兼容的函数接口 ====================
async def draft_run(
db: Session,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""试运行 Agent向后兼容接口"""
service = AppService(db)
return await service.draft_run(
app_id=app_id,
message=message,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
workspace_id=workspace_id
)
async def draft_run_stream(
db: Session,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
):
"""试运行 Agent 流式返回(向后兼容接口)"""
service = AppService(db)
async for event in service.draft_run_stream(
app_id=app_id,
message=message,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
workspace_id=workspace_id
):
yield event
# ==================== 依赖注入函数 ==================== # ==================== 依赖注入函数 ====================
def get_app_service( def get_app_service(

View File

@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
from app.celery_app import celery_app from app.celery_app import celery_app
from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
@@ -26,6 +27,7 @@ from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput from app.schemas.app_schema import FileInput
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service from app.services import task_service
from app.services.conversation_service import ConversationService
from app.services.langchain_tool_server import Search from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_parameter_merger import ModelParameterMerger
@@ -52,8 +54,12 @@ class LongTermMemoryInput(BaseModel):
description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写") description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None, def create_long_term_memory_tool(
user_rag_memory_id: Optional[str] = None): memory_config: Dict[str, Any],
end_user_id: str,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
):
"""创建记忆工具, """创建记忆工具,
@@ -61,6 +67,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
memory_config: 记忆配置 memory_config: 记忆配置
end_user_id: 用户ID end_user_id: 用户ID
storage_type: 存储类型(可选) storage_type: 存储类型(可选)
user_rag_memory_id: 用户RAG记忆ID可选
Returns: Returns:
长期记忆工具 长期记忆工具
@@ -188,7 +195,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
Args: Args:
query: 需要检索的问题或关键词 kb_config: 知识库配置
kb_ids: 知识库ID列表
user_id: 用户ID
Returns: Returns:
检索到的相关知识内容 检索到的相关知识内容
@@ -232,17 +241,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
return knowledge_retrieval_tool return knowledge_retrieval_tool
class DraftRunService: class AgentRunService:
"""运行服务类""" """Agent运行服务类"""
def __init__(self, db: Session): def __init__(self, db: Session):
"""初始化试运行服务 """Agent运行服务
Args: Args:
db: 数据库会话 db: 数据库会话
""" """
self.db = db self.db = db
@staticmethod
def prepare_variables(
input_vars: dict | None,
variables_config: dict
) -> dict:
input_vars = input_vars or {}
for variable in variables_config:
if variable.get("required") and variable.get("name") not in input_vars:
raise ValueError(f"The required parameter '{variable.get('name')}' was not provided")
return input_vars
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
"""加载工具配置"""
if not tools_config:
return []
tools = []
tool_service = ToolService(self.db)
if tools_config and isinstance(tools_config, list):
for tool_config in tools_config:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif tools_config and isinstance(tools_config, dict):
web_search_choice = tools_config.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search and web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
return tools
def load_skill_config(
self,
skills_config: dict | None,
message: str, tenant_id
) -> tuple[list, str]:
if not skills_config:
return [], ""
tools = []
skill_prompts = ""
skill_enable = skills_config.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills_config)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
skill_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
return tools, skill_prompts
def load_knowledge_retrieval_config(
self,
knowledge_retrieval_config: dict | None,
user_id
) -> list:
if not knowledge_retrieval_config:
return []
tools = []
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# 创建知识库检索工具
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id)
tools.append(kb_tool)
logger.debug(
"已添加知识库检索工具",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
return tools
def load_memory_config(
self,
memory_config: dict | None,
user_id,
storage_type,
user_rag_memory_id
) -> tuple[list, bool]:
"""加载长期记忆配置"""
if not memory_config:
return [], False
tools = []
if memory_config.get("enabled"):
if user_id:
# 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)
logger.debug(
"已添加长期记忆工具",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
return tools, bool(memory_config.get("enabled"))
async def run( async def run(
self, self,
*, *,
@@ -270,19 +403,21 @@ class DraftRunService:
conversation_id: 会话ID用于多轮对话 conversation_id: 会话ID用于多轮对话
user_id: 用户ID user_id: 用户ID
variables: 自定义变量参数值 variables: 自定义变量参数值
storage_type: 存储类型(可选)
user_rag_memory_id: 用户RAG记忆ID可选
web_search: 是否启用网络搜索默认True
memory: 是否启用长期记忆默认True
sub_agent: 是否为子代理调用默认False
files: 多模态文件列表(可选)
Returns: Returns:
Dict: 包含 AI 回复和元数据的字典 Dict: 包含 AI 回复和元数据的字典
""" """
memory_flag = False
print('===========', storage_type)
print(user_id)
if variables == None: variables = {}
from app.core.agent.langchain_agent import LangChainAgent
start_time = time.time() start_time = time.time()
tools_config: dict | list | None = agent_config.tools
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory
try: try:
# 1. 获取 API Key 配置 # 1. 获取 API Key 配置
@@ -302,112 +437,40 @@ class DraftRunService:
agent_config=agent_config agent_config=agent_config
) )
items_params = variables if sub_agent:
variables = self.prepare_variables(variables, agent_config.variables)
else:
# FIXME: subagent input valid
variables = variables or {}
system_prompt = render_prompt_message( system_prompt = render_prompt_message(
agent_config.system_prompt, # 修正拼写错误 agent_config.system_prompt,
PromptMessageRole.USER, PromptMessageRole.USER,
items_params variables
) )
# 3. 处理系统提示词(支持变量替换) # 3. 处理系统提示词(支持变量替换)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
print('系统提示词:', system_prompt)
# 4. 准备工具列表 # 4. 准备工具列表
tools = [] tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具 # 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
if hasattr(agent_config, 'tools') and agent_config.tools: skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
for tool_config in agent_config.tools:
print("+" * 50)
print(f"agent_config:{agent_config}")
print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools) tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具") if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
# 应用动态过滤 tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
knowledge_bases = kb_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# 创建知识库检索工具
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
tools.append(kb_tool)
logger.debug(
"已添加知识库检索工具",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False
if memory: if memory:
if agent_config.memory and agent_config.memory.get("enabled"): memory_tools, memory_flag = self.load_memory_config(
memory_flag = True memory_config, user_id, storage_type, user_rag_memory_id
memory_config = agent_config.memory
if user_id:
# 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)
logger.debug(
"已添加长期记忆工具",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
) )
tools.extend(memory_tools)
# 4. 创建 LangChain Agent # 4. 创建 LangChain Agent
agent = LangChainAgent( agent = LangChainAgent(
@@ -432,7 +495,7 @@ class DraftRunService:
# 6. 加载历史消息 # 6. 加载历史消息
history = [] history = []
if agent_config.memory and agent_config.memory.get("enabled"): if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history( history = await self._load_conversation_history(
conversation_id=conversation_id, conversation_id=conversation_id,
max_history=agent_config.memory.get("max_history", 10) max_history=agent_config.memory.get("max_history", 10)
@@ -482,7 +545,7 @@ class DraftRunService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
# 9. 保存会话消息 # 9. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): if not sub_agent and memory_config and memory_config.get("enabled"):
await self._save_conversation_message( await self._save_conversation_message(
conversation_id=conversation_id, conversation_id=conversation_id,
user_message=message, user_message=message,
@@ -557,16 +620,21 @@ class DraftRunService:
Yields: Yields:
str: SSE 格式的事件数据 str: SSE 格式的事件数据
""" """
memory_flag = False tools_config: dict | list | None = agent_config.tools
if variables == None: variables = {} skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
from app.core.agent.langchain_agent import LangChainAgent memory_config: dict | None = agent_config.memory
start_time = time.time() start_time = time.time()
try: try:
# 1. 获取 API Key 配置 # 1. 获取 API Key 配置
api_key_config = await self._get_api_key(model_config.id) api_key_config = await self._get_api_key(model_config.id)
if not sub_agent:
variables = self.prepare_variables(variables, agent_config.variables)
else:
# FIXME: subagent input valid
variables = variables or {}
# 2. 合并模型参数 # 2. 合并模型参数
effective_params = ModelParameterMerger.get_effective_parameters( effective_params = ModelParameterMerger.get_effective_parameters(
@@ -588,95 +656,22 @@ class DraftRunService:
# 4. 准备工具列表 # 4. 准备工具列表
tools = [] tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具 # 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
for tool_config in agent_config.tools: skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
# print("+"*50)
# print(f"agent_config:{agent_config}")
# print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools) tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具") if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
knowledge_bases = kb_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# 创建知识库检索工具
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
tools.append(kb_tool)
logger.debug(
"已添加知识库检索工具",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False
if memory: if memory:
if agent_config.memory and agent_config.memory.get("enabled"): memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type,
memory_flag = True
memory_config = agent_config.memory
if user_id:
# 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id) user_rag_memory_id)
tools.append(memory_tool) tools.extend(memory_tools)
logger.debug(
"已添加长期记忆工具",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
# 4. 创建 LangChain Agent # 4. 创建 LangChain Agent
agent = LangChainAgent( agent = LangChainAgent(
@@ -702,10 +697,10 @@ class DraftRunService:
# 6. 加载历史消息 # 6. 加载历史消息
history = [] history = []
if agent_config.memory and agent_config.memory.get("enabled"): if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history( history = await self._load_conversation_history(
conversation_id=conversation_id, conversation_id=conversation_id,
max_history=agent_config.memory.get("max_history", 10) max_history=memory_config.get("max_history", 10)
) )
# 6. 处理多模态文件 # 6. 处理多模态文件
@@ -763,7 +758,7 @@ class DraftRunService:
}) })
# 10. 保存会话消息 # 10. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): if not sub_agent and memory_config and memory_config.get("enabled"):
await self._save_conversation_message( await self._save_conversation_message(
conversation_id=conversation_id, conversation_id=conversation_id,
user_message=message, user_message=message,
@@ -969,7 +964,6 @@ class DraftRunService:
List[Dict]: 历史消息列表 List[Dict]: 历史消息列表
""" """
try: try:
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db) conversation_service = ConversationService(self.db)
history = conversation_service.get_conversation_history( history = conversation_service.get_conversation_history(
@@ -1489,6 +1483,15 @@ class DraftRunService:
"conversation_id": returned_conversation_id, "conversation_id": returned_conversation_id,
"content": chunk "content": chunk
})) }))
if event_type == "error" and event_data:
await event_queue.put(self._format_sse_event("model_error", {
"model_index": idx,
"model_config_id": model_config_id,
"label": model_label,
"conversation_id": returned_conversation_id,
"error": event_data.get("error", "未知错误")
}))
except Exception as e: except Exception as e:
logger.warning(f"解析流式事件失败: {e}") logger.warning(f"解析流式事件失败: {e}")
finally: finally:
@@ -1673,41 +1676,3 @@ class DraftRunService:
"total_time": sum(r.get("elapsed_time", 0) for r in results) "total_time": sum(r.get("elapsed_time", 0) for r in results)
} }
) )
async def draft_run(
db: Session,
*,
agent_config: AgentConfig,
model_config: ModelConfig,
message: str,
user_id: Optional[str] = None,
kb_ids: Optional[List[str]] = None,
similarity_threshold: float = 0.7,
top_k: int = 3
) -> Dict[str, Any]:
"""试运行 Agent便捷函数
Args:
db: 数据库会话
agent_config: Agent 配置
model_config: 模型配置
message: 用户消息
user_id: 用户ID
kb_ids: 知识库ID列表
similarity_threshold: 相似度阈值
top_k: 检索返回的文档数量
Returns:
Dict: 包含 AI 回复和元数据的字典
"""
service = DraftRunService(db)
return await service.run(
agent_config=agent_config,
model_config=model_config,
message=message,
user_id=user_id,
kb_ids=kb_ids,
similarity_threshold=similarity_threshold,
top_k=top_k
)

View File

@@ -9,6 +9,8 @@ load_dotenv()
# 读取web_search环境变量 # 读取web_search环境变量
web_search_value = os.getenv('web_search') web_search_value = os.getenv('web_search')
def Search(query): def Search(query):
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions" url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
api_key = web_search_value api_key = web_search_value
@@ -18,23 +20,24 @@ def Search(query):
"role": "user", "role": "user",
"content": query "content": query
} }
], #搜索输入 ], # 搜索输入
"edition":"standard", #搜索版本。默认为standard。可选值standard完整版本。lite标准版本对召回规模和精排条数简化后的版本时延表现更好效果略弱于完整版。 "edition": "standard", # 搜索版本。默认为standard。可选值standard完整版本。lite标准版本对召回规模和精排条数简化后的版本时延表现更好效果略弱于完整版。
"search_source": "baidu_search_v2", #使用的搜索引擎版本 "search_source": "baidu_search_v2", # 使用的搜索引擎版本
"resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态网页top_k最大取值为50视频top_k最大为10图片top_k最大为30阿拉丁top_k最大为5 "resource_type_filter": [{"type": "web", "top_k": 20}],
# 支持设置网页、视频、图片、阿拉丁搜索模态网页top_k最大取值为50视频top_k最大为10图片top_k最大为30阿拉丁top_k最大为5
"search_filter": { "search_filter": {
"range": { "range": {
"page_time": { "page_time": {
"gte": "now-1w/d", #时间查询参数,大于或等于 "gte": "now-1w/d", # 时间查询参数,大于或等于
"lt": "now/d", #时间查询参数,小于 "lt": "now/d", # 时间查询参数,小于
"gt": "", #时间查询参数,大于 "gt": "", # 时间查询参数,大于
"lte": "" #时间查询参数,小于或等于 "lte": "" # 时间查询参数,小于或等于
} }
} }
}, },
"block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表 "block_websites": ["tieba.baidu.com"], # 需要屏蔽的站点列表
"search_recency_filter":"week", #根据网页发布时间进行筛选可填值为week,month,semiyear,year "search_recency_filter": "week", # 根据网页发布时间进行筛选可填值为week,month,semiyear,year
"enable_full_content":True #是否输出网页完整原文 "enable_full_content": True # 是否输出网页完整原文
}, ensure_ascii=False) }, ensure_ascii=False)
headers = { headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -42,10 +45,10 @@ def Search(query):
} }
response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json() response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json()
content=[] content = []
for i in response['references']: for i in response['references']:
title=i['title'] title = i['title']
snippet=i['snippet'] snippet = i['snippet']
content.append(title+';'+snippet) content.append(title + ';' + snippet)
content=''.join(content) content = ''.join(content)
return content return content

View File

@@ -123,11 +123,14 @@ class MultiAgentOrchestrator:
user_id: 用户 ID user_id: 用户 ID
variables: 变量参数 variables: 变量参数
use_llm_routing: 是否使用 LLM 路由 use_llm_routing: 是否使用 LLM 路由
web_search: 是否启用网络搜索
memory: 是否启用记忆功能
storage_type: 存储类型
user_rag_memory_id: 用户 RAG 记忆 ID
Yields: Yields:
SSE 格式的事件流 SSE 格式的事件流
""" """
import json
start_time = time.time() start_time = time.time()
@@ -200,7 +203,8 @@ class MultiAgentOrchestrator:
except Exception as e: except Exception as e:
logger.error( logger.error(
"多 Agent 任务执行失败(流式)", "多 Agent 任务执行失败(流式)",
extra={"error": str(e), "mode": self._normalized_mode} extra={"error": str(e), "mode": self._normalized_mode},
exc_info=True
) )
# 发送错误事件 # 发送错误事件
yield self._format_sse_event("error", { yield self._format_sse_event("error", {
@@ -1267,7 +1271,7 @@ class MultiAgentOrchestrator:
Yields: Yields:
SSE 格式的事件流 SSE 格式的事件流
""" """
from app.services.draft_run_service import DraftRunService from app.services.draft_run_service import AgentRunService
# 获取模型配置 # 获取模型配置
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
@@ -1278,7 +1282,7 @@ class MultiAgentOrchestrator:
) )
# 流式执行 Agent # 流式执行 Agent
draft_service = DraftRunService(self.db) draft_service = AgentRunService(self.db)
async for event in draft_service.run_stream( async for event in draft_service.run_stream(
agent_config=agent_config, agent_config=agent_config,
model_config=model_config, model_config=model_config,
@@ -1320,7 +1324,7 @@ class MultiAgentOrchestrator:
Returns: Returns:
执行结果 执行结果
""" """
from app.services.draft_run_service import DraftRunService from app.services.draft_run_service import AgentRunService
# 获取模型配置 # 获取模型配置
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
@@ -1331,7 +1335,7 @@ class MultiAgentOrchestrator:
) )
# 执行 Agent # 执行 Agent
draft_service = DraftRunService(self.db) draft_service = AgentRunService(self.db)
result = await draft_service.run( result = await draft_service.run(
agent_config=agent_config, agent_config=agent_config,
model_config=model_config, model_config=model_config,
@@ -1633,6 +1637,7 @@ class MultiAgentOrchestrator:
self.memory = config_data.get("memory") self.memory = config_data.get("memory")
self.variables = config_data.get("variables", []) self.variables = config_data.get("variables", [])
self.tools = config_data.get("tools", {}) self.tools = config_data.get("tools", {})
self.skills = config_data.get("skills", {})
self.default_model_config_id = release.default_model_config_id self.default_model_config_id = release.default_model_config_id
return AgentConfigProxy(release, app, config_data) return AgentConfigProxy(release, app, config_data)

View File

@@ -121,7 +121,7 @@ class SkillService:
if skill and skill.is_active: if skill and skill.is_active:
# 加载技能关联的工具 # 加载技能关联的工具
for tool_config in skill.tools: for tool_config in skill.tools:
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) tool = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool: if tool:
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None)) langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool) tools.append(langchain_tool)

View File

@@ -209,7 +209,7 @@ class ToolService:
try: try:
# 获取工具实例 # 获取工具实例
tool = self._get_tool_instance(tool_id, tenant_id) tool = self.get_tool_instance(tool_id, tenant_id)
if not tool: if not tool:
return ToolResult.error_result( return ToolResult.error_result(
error=f"工具不存在: {tool_id}", error=f"工具不存在: {tool_id}",
@@ -335,7 +335,7 @@ class ToolService:
return [] return []
# 获取工具实例 # 获取工具实例
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
if not tool_instance: if not tool_instance:
return [] return []
@@ -792,7 +792,7 @@ class ToolService:
"""获取工具配置""" """获取工具配置"""
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
"""获取工具实例""" """获取工具实例"""
if tool_id in self._tool_cache: if tool_id in self._tool_cache:
return self._tool_cache[tool_id] return self._tool_cache[tool_id]
@@ -1416,7 +1416,7 @@ class ToolService:
"""测试内置工具连接""" """测试内置工具连接"""
try: try:
# 获取工具实例 # 获取工具实例
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
if not tool_instance: if not tool_instance:
return {"success": False, "message": "无法创建工具实例"} return {"success": False, "message": "无法创建工具实例"}

View File

@@ -16,6 +16,7 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.executor import execute_workflow, execute_workflow_stream from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config from app.core.workflow.validator import validate_workflow_config
from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db from app.db import get_db
from app.models import App from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.models.workflow_model import WorkflowConfig, WorkflowExecution
@@ -453,11 +454,14 @@ class WorkflowService:
files_struct = [] files_struct = []
for file in files: for file in files:
files_struct.append( files_struct.append(
{ FileObject(
"type": file.type, type=file.type,
"url": await self.multimodal_service.get_file_url(file), url=await self.multimodal_service.get_file_url(file),
"__file": True transfer_method=file.transfer_method,
} file_id=str(file.upload_file_id),
origin_file_type=file.file_type,
is_file=True
).model_dump()
) )
return files_struct return files_struct