Merge branch 'develop' into feature/20251219_myh

# Conflicts:
#	api/app/core/workflow/executor.py
#	api/app/core/workflow/nodes/node_factory.py
#	api/app/core/workflow/nodes/question_classifier/node.py
This commit is contained in:
mengyonghao
2026-01-05 11:10:01 +08:00
284 changed files with 21282 additions and 2779 deletions

View File

@@ -33,6 +33,7 @@ from . import (
emotion_config_controller,
prompt_optimizer_controller,
tool_controller,
home_page_controller,
)
from . import user_memory_controllers
@@ -70,5 +71,6 @@ manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(tool_controller.router)
manager_router.include_router(home_page_controller.router)
__all__ = ["manager_router"]

View File

@@ -0,0 +1,29 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.services.home_page_service import HomePageService
router = APIRouter(prefix="/home-page", tags=["Home Page"])
@router.get("/statistics", response_model=ApiResponse)
def get_home_statistics(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取首页统计数据"""
statistics = HomePageService.get_home_statistics(db, current_user.tenant_id)
return success(data=statistics, msg="统计数据获取成功")
@router.get("/workspaces", response_model=ApiResponse)
def get_workspace_list(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取工作空间列表"""
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
return success(data=workspace_list, msg="工作空间列表获取成功")

View File

@@ -1,26 +1,28 @@
from typing import Optional
import datetime
import json
from typing import Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.core.rag.common import settings
from app.core.rag.llm.chat_model import Base
from app.core.rag.nlp import rag_tokenizer, search
from app.core.rag.prompts.generator import graph_entity_types
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models import knowledge_model, document_model, file_model
from app.schemas import knowledge_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledge_service, document_service
from app.core.rag.llm.chat_model import Base
from app.core.rag.prompts.generator import graph_entity_types
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
from app.core.rag.nlp import rag_tokenizer, search
from app.core.rag.common import settings
from app.celery_app import celery_app
from app.services.model_service import ModelConfigService
# Obtain a dedicated API logger
api_logger = get_api_logger()
@@ -47,6 +49,45 @@ def get_parser_types():
return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType))
@router.get("/knowledge_graph_entity_types", response_model=ApiResponse)
async def get_knowledge_graph_entity_types(
llm_id: uuid.UUID,
scenario: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
get knowledge graph entity types based on llm_id
"""
api_logger.info(f"Obtain details of the knowledge graph: llm_id={llm_id}, username: {current_user.username}")
try:
# 1. Check whether the model exists
api_logger.debug(f"Check whether the model exists: {llm_id}")
config = ModelConfigService.get_model_by_id(db=db, model_id=llm_id)
if not config:
api_logger.warning(
f"The model does not exist or you do not have permission to access it: llm_id={llm_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The model does not exist or you do not have permission to access it"
)
# 2. Prepare to configure chat_mdl information
chat_model = Base(
key=config.api_keys[0].api_key,
model_name=config.api_keys[0].model_name,
base_url=config.api_keys[0].api_base
)
response = graph_entity_types(chat_model, scenario)
return success(data=response, msg="Successfully obtained knowledge graph entity types")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"get knowledge graph entity types failed: llm_id={llm_id} - {str(e)}")
raise
@router.get("/knowledges", response_model=ApiResponse)
async def get_knowledges(
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
@@ -379,7 +420,7 @@ async def delete_knowledge_graph(
current_user: User = Depends(get_current_user)
):
"""
Soft-delete knowledge graph
delete knowledge graph
"""
api_logger.info(f"Request to delete knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
@@ -442,42 +483,3 @@ async def rebuild_knowledge_graph(
except Exception as e:
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.get("/{knowledge_id}/knowledge_graph_entity_types", response_model=ApiResponse)
async def get_knowledge_graph_entity_types(
knowledge_id: uuid.UUID,
scenario: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
get knowledge graph entity types based on knowledge_id
"""
api_logger.info(f"Obtain details of the knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Check whether the knowledge base exists
api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(
f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or you do not have permission to access it"
)
# 2. Prepare to configure chat_mdl information
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
response = graph_entity_types(chat_model, scenario)
return success(data=response, msg="Successfully obtained knowledge graph entity types")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"get knowledge graph entity types failed: knowledge_id={knowledge_id} - {str(e)}")
raise

View File

@@ -1,6 +1,6 @@
import hashlib
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
@@ -17,6 +17,8 @@ from app.services.auth_service import create_access_token
from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
router = APIRouter(prefix="/public/share", tags=["Public Share"])
logger = get_business_logger()
@@ -265,7 +267,8 @@ def get_conversation(
async def chat(
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
):
"""发送消息并获取回复
@@ -285,7 +288,7 @@ async def chat(
password = None # Token 认证不需要密码
# end_user_id = user_id
other_id = user_id
# 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType
@@ -307,7 +310,7 @@ async def chat(
other_id=other_id,
original_user_id=user_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
appid=share.app_id
"""获取存储类型和工作空间的ID"""
@@ -390,15 +393,38 @@ async def chat(
if app_type == AppType.AGENT:
# 流式返回
if payload.stream:
# async def event_generator():
# async for event in service.chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
async def event_generator():
async for event in service.chat_stream(
share_token=share_token,
async for event in app_chat_service.agnet_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
user_id= str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
web_search=payload.web_search,
config=payload.agent_config,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
@@ -414,32 +440,43 @@ async def chat(
"X-Accel-Buffering": "no"
}
)
# 非流式返回
result = await service.chat(
share_token=share_token,
# result = await service.chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
result = await app_chat_service.agnet_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
config= payload.agent_config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
return success(data=conversation_schema.ChatResponse(**result))
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
# 多 Agent 流式返回
config = multi_agent_config_4_app_release(release)
if payload.stream:
async def event_generator():
async for event in service.multi_agent_chat_stream(
share_token=share_token,
async for event in app_chat_service.multi_agent_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
@@ -458,20 +495,62 @@ async def chat(
)
# 多 Agent 非流式返回
result = await service.multi_agent_chat(
share_token=share_token,
result = await app_chat_service.multi_agent_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
password=password,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
return success(data=conversation_schema.ChatResponse(**result))
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
# 多 Agent 流式返回
# if payload.stream:
# async def event_generator():
# async for event in service.multi_agent_chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
# # 多 Agent 非流式返回
# result = await service.multi_agent_chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode

View File

@@ -21,7 +21,7 @@ from app.schemas.api_key_schema import ApiKeyAuth
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.conversation_service import ConversationService, get_conversation_service
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
from app.services.app_service import get_app_service, AppService
router = APIRouter(prefix="/app", tags=["V1 - App API"])
@@ -137,10 +137,10 @@ async def chat(
if app_type == AppType.AGENT:
print("="*50)
print(app.current_release.default_model_config_id)
# print("="*50)
# print(app.current_release.default_model_config_id)
agent_config = agent_config_4_app_release(app.current_release)
print(agent_config.default_model_config_id)
# print(agent_config.default_model_config_id)
# 流式返回
if payload.stream:
async def event_generator():
@@ -182,7 +182,7 @@ async def chat(
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
# 多 Agent 流式返回
config = dict_to_multi_agent_config(app.current_release.config,app.id)
config = multi_agent_config_4_app_release(app.current_release)
if payload.stream:
async def event_generator():
async for event in app_chat_service.multi_agent_chat_stream(

View File

@@ -60,6 +60,22 @@ async def list_tools(
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{tool_id}/methods", response_model=ApiResponse)
async def get_tool_methods(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具的所有方法"""
try:
methods = await service.get_tool_methods(tool_id, current_user.tenant_id)
if methods is None:
raise HTTPException(status_code=404, detail="工具不存在")
return success(data=methods, msg="获取工具方法成功")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{tool_id}", response_model=ApiResponse)
async def get_tool(
tool_id: str,
@@ -159,7 +175,8 @@ async def execute_tool(
workspace_id=current_user.current_workspace_id,
timeout=request.timeout
)
if not result.success:
raise HTTPException(status_code=400, detail=result["error"])
return success(
data={
"success": result.success,

View File

@@ -3,7 +3,7 @@ import secrets
from typing import Optional, Union
from datetime import datetime
from app.schemas.api_key_schema import ApiKeyType
from app.models.api_key_model import ApiKeyType
from fastapi import Response
from fastapi.responses import JSONResponse

View File

@@ -48,7 +48,6 @@ class RAGExcelParser:
logging.info(f"pandas with default engine load error: {ex}, try calamine instead")
file_like_object.seek(0)
df = pd.read_excel(file_like_object, engine="calamine")
print("lxc1")
return RAGExcelParser._dataframe_to_workbook(df)
except Exception as e_pandas:
raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}")
@@ -215,19 +214,35 @@ class RAGExcelParser:
continue
if not rows:
continue
# 获取表头
ti = list(rows[0])
for r in list(rows[1:]):
fields = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
fields.append(t)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
header_fields = []
for cell in ti:
if cell.value: # 只添加有值的表头
header_fields.append(str(cell.value))
# 如果有数据行,处理数据行;否则只处理表头
data_rows = rows[1:]
if data_rows:
for r in data_rows:
fields = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
fields.append(t)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
else:
# 只有表头的情况
if header_fields:
line = "; ".join(header_fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
return res
@staticmethod

View File

@@ -1,7 +1,7 @@
"""工具管理核心模块"""
from .base import BaseTool, ToolResult, ToolParameter
from .langchain_adapter import LangchainAdapter
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
from app.core.tools.langchain_adapter import LangchainAdapter
# 可选导入,避免导入错误
try:

View File

@@ -193,7 +193,7 @@ class BaseTool(ABC):
def to_langchain_tool(self):
"""转换为Langchain工具格式"""
from .langchain_adapter import LangchainAdapter
from app.core.tools.langchain_adapter import LangchainAdapter
return LangchainAdapter.convert_tool(self)
def __repr__(self):

View File

@@ -1,11 +1,11 @@
"""内置工具模块"""
from .base import BuiltinTool
from .datetime_tool import DateTimeTool
from .json_tool import JsonTool
from .baidu_search_tool import BaiduSearchTool
from .mineru_tool import MinerUTool
from .textin_tool import TextInTool
from app.core.tools.builtin.base import BuiltinTool
from app.core.tools.builtin.datetime_tool import DateTimeTool
from app.core.tools.builtin.json_tool import JsonTool
from app.core.tools.builtin.baidu_search_tool import BaiduSearchTool
from app.core.tools.builtin.mineru_tool import MinerUTool
from app.core.tools.builtin.textin_tool import TextInTool
__all__ = [
"BuiltinTool",

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
from app.core.tools.builtin.base import BuiltinTool
class BaiduSearchTool(BuiltinTool):

View File

@@ -5,7 +5,7 @@ from typing import List
import pytz
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
from app.core.tools.builtin.base import BuiltinTool
class DateTimeTool(BuiltinTool):
@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"]
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
),
ToolParameter(
name="input_value",

View File

@@ -7,7 +7,7 @@ import xml.etree.ElementTree as ET
from xml.dom import minidom
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
from app.core.tools.builtin.base import BuiltinTool
class JsonTool(BuiltinTool):
@@ -29,8 +29,7 @@ class JsonTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge",
"extract", "insert", "replace", "delete", "parse"]
enum=["insert", "replace", "delete", "parse"]
),
ToolParameter(
name="input_data",

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
from app.core.tools.builtin.base import BuiltinTool
class MinerUTool(BuiltinTool):

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
from app.core.tools.builtin.base import BuiltinTool
class TextInTool(BuiltinTool):

View File

@@ -1,8 +1,8 @@
"""自定义工具模块"""
from .base import CustomTool
from .schema_parser import OpenAPISchemaParser
from .auth_manager import AuthManager
from app.core.tools.custom.base import CustomTool
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
from app.core.tools.custom.auth_manager import AuthManager
__all__ = [
"CustomTool",

View File

@@ -1,8 +1,8 @@
"""MCP工具模块"""
from .base import MCPTool
from .client import MCPClient, MCPConnectionPool
from .service_manager import MCPServiceManager
from app.core.tools.mcp.base import MCPTool
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
from app.core.tools.mcp.service_manager import MCPServiceManager
__all__ = [
"MCPTool",

View File

@@ -1,7 +1,6 @@
"""MCP工具基类"""
import time
from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool

View File

@@ -204,7 +204,7 @@ class MCPClient:
)
init_response = json.loads(response)
if "error" in init_response:
if init_response.get("error", None) is not None:
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
return True
@@ -325,7 +325,7 @@ class MCPClient:
try:
response = await self._send_request(request_data, timeout)
if "error" in response:
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")

View File

@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
from app.core.logging_config import get_business_logger
from .client import MCPClient, MCPConnectionPool
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
logger = get_business_logger()

View File

@@ -17,6 +17,8 @@ from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.tool import ToolNode
__all__ = [
"BaseNode",
@@ -33,5 +35,7 @@ __all__ = [
"AssignerNode",
"HttpRequestNode",
"JinjaRenderNode",
"ParameterExtractorNode"
"ParameterExtractorNode",
"QuestionClassifierNode",
"ToolNode"
]

View File

@@ -21,6 +21,7 @@ from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [
@@ -45,4 +46,5 @@ __all__ = [
"LoopNodeConfig",
"IterationNodeConfig",
"QuestionClassifierNodeConfig"
"ToolNodeConfig"
]

View File

@@ -24,6 +24,7 @@ from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode
logger = logging.getLogger(__name__)
@@ -44,7 +45,8 @@ WorkflowNode = Union[
CycleGraphNode,
BreakNode,
ParameterExtractorNode,
QuestionClassifierNode
QuestionClassifierNode,
ToolNode
]
@@ -73,6 +75,7 @@ class NodeFactory:
NodeType.ITERATION: CycleGraphNode,
NodeType.BREAK: BreakNode,
NodeType.CYCLE_START: StartNode,
NodeType.TOOL: ToolNode,
}
@classmethod

View File

@@ -26,4 +26,3 @@ class QuestionClassifierNodeConfig(BaseNodeConfig):
default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
description="用户提示词模板"
)
output_variable: str = Field(default="class_name", description="输出分类结果的变量名")

View File

@@ -12,32 +12,36 @@ from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
DEFAULT_CASE_PREFIX = "CASE"
DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode):
"""问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = QuestionClassifierNodeConfig(**self.config)
self.category_to_case_map = self._build_category_case_map()
def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例"""
with get_db_read() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
base_url = api_config.api_base
model_type = config.type
return RedBearLLM(
RedBearModelConfig(
model_name=model_name,
@@ -48,47 +52,72 @@ class QuestionClassifierNode(BaseNode):
type=ModelType(model_type)
)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
def _build_category_case_map(self) -> dict[str, str]:
"""
预构建 分类名称 -> CASE标识 的映射字典
示例:{"产品咨询": "CASE1", "售后问题": "CASE2"}
"""
category_map = {}
categories = self.typed_config.categories or []
for idx, class_item in enumerate(categories, start=1):
category_name = class_item.class_name.strip()
case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
category_map[category_name] = case_tag
return category_map
async def execute(self, state: WorkflowState) -> str:
"""执行问题分类"""
question = self.typed_config.input_variable
supplement_prompt = ""
if self.typed_config.user_supplement_prompt is not None:
supplement_prompt = self.typed_config.user_supplement_prompt
category_names = [class_item.class_name for class_item in self.typed_config.categories]
supplement_prompt = self.typed_config.user_supplement_prompt or ""
categories = self.typed_config.categories or []
category_names = [class_item.class_name.strip() for class_item in categories]
category_count = len(category_names)
if not question:
logger.warning(f"节点 {self.node_id} 未获取到输入问题")
return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"}
logger.warning(
f"节点 {self.node_id} 未获取到输入问题,使用默认分支"
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count}"
)
# 若分类列表为空返回默认unknown分支否则返回CASE1
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
llm = self._get_llm_instance()
try:
llm = self._get_llm_instance()
# 渲染用户提示词模板,支持工作流变量
user_prompt = self._render_template(
self.typed_config.user_prompt.format(
question=question,
categories=", ".join(category_names),
supplement_prompt=supplement_prompt
),
state
)
# 渲染用户提示词模板,支持工作流变量
user_prompt = self._render_template(
self.typed_config.user_prompt.format(
question=question,
categories=", ".join(category_names),
supplement_prompt=supplement_prompt
),
state
)
messages = [
("system", self.typed_config.system_prompt),
("user", user_prompt),
]
messages = [
("system", self.typed_config.system_prompt),
("user", user_prompt),
]
response = await llm.ainvoke(messages)
result = response.content.strip()
response = await llm.ainvoke(messages)
result = response.content.strip()
if result in category_names:
category = result
else:
logger.warning(f"LLM返回了未知类别: {result}")
category = category_names[0] if category_names else "unknown"
if result in category_names:
category = result
else:
logger.warning(f"LLM返回了未知类别: {result}")
category = category_names[0] if category_names else "unknown"
log_supplement = supplement_prompt if supplement_prompt else ""
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
log_supplement = supplement_prompt if supplement_prompt else ""
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
return {self.typed_config.output_variable: category}
return f"CASE{category_names.index(category) + 1}"
except Exception as e:
logger.error(
f"节点 {self.node_id} 分类执行异常:{str(e)}",
exc_info=True # 打印堆栈信息,便于调试
)
# 异常时返回默认分支,保证工作流容错性
if category_count > 0:
return DEFAULT_EMPTY_QUESTION_CASE
return "unknown"

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.tool.node import ToolNode
__all__ = ["ToolNode", "ToolNodeConfig"]

View File

@@ -0,0 +1,9 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
class ToolNodeConfig(BaseNodeConfig):
"""工具节点配置"""
tool_id: str = Field(..., description="工具ID")
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")

View File

@@ -0,0 +1,72 @@
import logging
import uuid
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.services.tool_service import ToolService
from app.db import get_db_read
logger = logging.getLogger(__name__)
class ToolNode(BaseNode):
"""工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = ToolNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行工具"""
# 获取租户ID和用户ID
tenant_id = self.get_variable("sys.tenant_id", state)
user_id = self.get_variable("sys.user_id", state)
# 如果没有租户ID尝试从工作流ID获取
if not tenant_id:
workflow_id = self.get_variable("sys.workflow_id", state)
if workflow_id:
from app.repositories.tool_repository import ToolRepository
with get_db_read() as db:
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
if not tenant_id:
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
# logger.error(f"节点 {self.node_id} 缺少租户ID")
# return {"error": "缺少租户ID"}
# 渲染工具参数
rendered_parameters = {}
for param_name, param_template in self.typed_config.tool_parameters.items():
rendered_value = self._render_template(param_template, state)
rendered_parameters[param_name] = rendered_value
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
print(self.typed_config.tool_id)
# 执行工具
with get_db_read() as db:
tool_service = ToolService(db)
result = await tool_service.execute_tool(
tool_id=self.typed_config.tool_id,
parameters=rendered_parameters,
tenant_id=tenant_id,
user_id=user_id
)
print(result)
if result.success:
logger.info(f"节点 {self.node_id} 工具执行成功")
return {
"success": True,
"data": result.data,
"execution_time": result.execution_time
}
else:
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
return {
"success": False,
"error": result.error,
"error_code": result.error_code,
"execution_time": result.execution_time
}

View File

@@ -0,0 +1,137 @@
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func
from uuid import UUID
from typing import Dict
from app.models.end_user_model import EndUser
from app.models.user_model import User
from app.models.workspace_model import Workspace, WorkspaceMember
from app.models.models_model import ModelConfig
from app.models.app_model import App
class HomePageRepository:
@staticmethod
def get_model_statistics(db: Session, tenant_id: UUID, month_start: datetime) -> tuple[int, int]:
"""获取模型统计数据"""
total_models = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True
).count()
new_models_this_month = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True,
ModelConfig.created_at >= month_start
).count()
return total_models, new_models_this_month
@staticmethod
def get_workspace_statistics(db: Session, tenant_id: UUID, month_start: datetime) -> tuple[int, int]:
"""获取工作空间统计数据"""
active_workspaces = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active == True
).count()
new_workspaces_this_month = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active == True,
Workspace.created_at >= month_start
).count()
return active_workspaces, new_workspaces_this_month
@staticmethod
def get_user_statistics(db: Session, tenant_id: UUID, month_start: datetime) -> tuple[int, int]:
"""获取用户统计数据"""
workspace_ids = db.query(Workspace.id).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active == True
).subquery()
total_users = db.query(EndUser).join(
App,
EndUser.app_id == App.id
).filter(
App.workspace_id.in_(workspace_ids),
App.is_active == True,
App.status == "active"
).count()
new_users_this_month = db.query(EndUser).join(
App,
EndUser.app_id == App.id
).filter(
App.workspace_id.in_(workspace_ids),
App.is_active == True,
App.status == "active",
EndUser.created_at >= month_start
).count()
return total_users, new_users_this_month
@staticmethod
def get_app_statistics(db: Session, tenant_id: UUID, week_start: datetime) -> tuple[int, int]:
"""获取应用统计数据"""
workspace_ids = db.query(Workspace.id).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active == True
).subquery()
running_apps = db.query(App).filter(
App.workspace_id.in_(workspace_ids),
App.is_active == True,
App.status == "active"
).count()
new_apps_this_week = db.query(App).filter(
App.workspace_id.in_(workspace_ids),
App.is_active == True,
App.status == "active",
App.created_at >= week_start
).count()
return running_apps, new_apps_this_week
@staticmethod
def get_workspaces_with_counts(db: Session, tenant_id: UUID) -> tuple[list[Workspace], Dict[UUID, int], Dict[UUID, int]]:
"""批量获取工作空间及其统计数据"""
# 获取工作空间列表
workspaces = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active == True
).all()
workspace_ids = [ws.id for ws in workspaces]
# 批量获取应用数量
app_counts = db.query(
App.workspace_id,
func.count(App.id).label('count')
).filter(
App.workspace_id.in_(workspace_ids),
App.is_active,
App.status == "active"
).group_by(App.workspace_id).all()
app_count_dict = {workspace_id: count for workspace_id, count in app_counts}
# 批量获取用户数量
user_counts = db.query(
App.workspace_id,
func.count(EndUser.id).label('count')
).join(
EndUser,
EndUser.app_id == App.id
).filter(
App.workspace_id.in_(workspace_ids),
App.is_active,
App.status == "active"
).group_by(App.workspace_id).all()
user_count_dict = {workspace_id: count for workspace_id, count in user_counts}
return workspaces, app_count_dict, user_count_dict

View File

@@ -1,10 +1,9 @@
"""工具数据访问层"""
import uuid
from typing import List, Optional, Dict, Any
from typing import List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import func, or_
from sqlalchemy import func
from app.repositories.base_repository import BaseRepository
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolExecution, ToolType, ToolStatus
@@ -14,6 +13,31 @@ from app.models.tool_model import (
class ToolRepository:
"""工具仓储类"""
@staticmethod
def get_tenant_id_by_workflow_id(db: Session, workflow_id: uuid.UUID) -> Optional[uuid.UUID]:
"""根据工作流ID获取tenant_id
Args:
db: 数据库会话
workflow_id: 工作流配置ID
Returns:
tenant_id或None
"""
from app.models.app_model import App
from app.models.workflow_model import WorkflowConfig
from app.models.workspace_model import Workspace
result = db.query(Workspace.tenant_id).join(
App, App.workspace_id == Workspace.id
).join(
WorkflowConfig, WorkflowConfig.app_id == App.id
).filter(
WorkflowConfig.id == workflow_id
).first()
return result[0] if result else None
@staticmethod
def find_by_tenant(
db: Session,

View File

@@ -0,0 +1,32 @@
from datetime import datetime
from pydantic import BaseModel, field_serializer
from typing import Optional
from app.core.api_key_utils import datetime_to_timestamp
class HomeStatistics(BaseModel):
"""首页统计数据"""
total_models: int
new_models_this_month: int
active_workspaces: int
new_workspaces_this_month: int
total_users: int
new_users_this_month: int
running_apps: int
new_apps_this_week: int
class WorkspaceInfo(BaseModel):
"""工作空间信息"""
id: str
name: str
icon: Optional[str]
description: Optional[str]
app_count: int
user_count: int
created_at: datetime
@field_serializer('created_at')
@classmethod
def serialize_datetime(cls, v: datetime) -> Optional[int]:
return datetime_to_timestamp(v)

View File

@@ -1203,11 +1203,11 @@ class AppService:
self._check_multi_agent_config(app_id)
# 3. 获取主 Agent 的模型配置 ID
master_agent = self.db.get(AgentConfig, multi_agent_cfg.master_agent_id)
default_model_config_id = master_agent.default_model_config_id if master_agent else None
default_model_config_id = multi_agent_cfg.default_model_config_id
# 4. 构建配置快照
config = {
"model_parameters":multi_agent_cfg.model_parameters,
"master_agent_id": str(multi_agent_cfg.master_agent_id),
"orchestration_mode": multi_agent_cfg.orchestration_mode,
"sub_agents": multi_agent_cfg.sub_agents,
@@ -1220,7 +1220,7 @@ class AppService:
"多智能体应用发布配置准备完成",
extra={
"app_id": str(app_id),
"master_agent_id": str(multi_agent_cfg.master_agent_id),
"default_model_config_id": str(default_model_config_id),
"sub_agent_count": len(multi_agent_cfg.sub_agents) if multi_agent_cfg.sub_agents else 0,
"orchestration_mode": multi_agent_cfg.orchestration_mode
}

View File

@@ -0,0 +1,67 @@
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from uuid import UUID
from app.repositories.home_page_repository import HomePageRepository
from app.schemas.home_page_schema import HomeStatistics, WorkspaceInfo
class HomePageService:
@staticmethod
def get_home_statistics(db: Session, tenant_id: UUID) -> HomeStatistics:
"""获取首页统计数据"""
# 计算时间范围
now = datetime.now()
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
week_start = now - timedelta(days=now.weekday())
week_start = week_start.replace(hour=0, minute=0, second=0, microsecond=0)
# 获取各项统计数据
total_models, new_models_this_month = HomePageRepository.get_model_statistics(
db, tenant_id, month_start
)
active_workspaces, new_workspaces_this_month = HomePageRepository.get_workspace_statistics(
db, tenant_id, month_start
)
total_users, new_users_this_month = HomePageRepository.get_user_statistics(
db, tenant_id, month_start
)
running_apps, new_apps_this_week = HomePageRepository.get_app_statistics(
db, tenant_id, week_start
)
return HomeStatistics(
total_models=total_models,
new_models_this_month=new_models_this_month,
active_workspaces=active_workspaces,
new_workspaces_this_month=new_workspaces_this_month,
total_users=total_users,
new_users_this_month=new_users_this_month,
running_apps=running_apps,
new_apps_this_week=new_apps_this_week
)
@staticmethod
def get_workspace_list(db: Session, tenant_id: UUID) -> list[WorkspaceInfo]:
"""获取工作空间列表(优化版本)"""
workspaces, app_count_dict, user_count_dict= HomePageRepository.get_workspaces_with_counts(
db, tenant_id
)
workspace_list = []
for workspace in workspaces:
workspace_info = WorkspaceInfo(
id=str(workspace.id),
name=workspace.name,
icon=workspace.icon,
description=workspace.description,
app_count=app_count_dict.get(workspace.id, 0),
user_count=user_count_dict.get(workspace.id, 0),
created_at=workspace.created_at
)
workspace_list.append(workspace_info)
return workspace_list

View File

@@ -890,13 +890,13 @@ class MultiAgentOrchestrator:
)
# 发送整合后的结果
yield self._format_sse_event("merge_complete", {
yield self._format_sse_event("message", {
"content": final_response
})
except Exception as e:
logger.error(f"Master Agent 整合失败,降级到 smart 模式: {str(e)}")
final_response = self._smart_merge_results(results, collaboration_strategy)
yield self._format_sse_event("merge_complete", {
yield self._format_sse_event("message", {
"content": final_response
})
else:
@@ -912,7 +912,7 @@ class MultiAgentOrchestrator:
# 只有在需要时才发送整合结果
if final_response and final_response != "":
yield self._format_sse_event("merge_complete", {
yield self._format_sse_event("message", {
"content": final_response
})

View File

@@ -297,6 +297,165 @@ class ToolService:
self.db.commit()
logger.info(f"租户 {tenant_id} 内置工具初始化完成")
async def get_tool_methods(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[List[Dict[str, Any]]]:
"""获取工具的所有方法
Args:
tool_id: 工具ID
tenant_id: 租户ID
Returns:
方法列表或None
"""
config = self._get_tool_config(tool_id, tenant_id)
if not config:
return None
try:
if config.tool_type == ToolType.BUILTIN.value:
return await self._get_builtin_tool_methods(config)
elif config.tool_type == ToolType.CUSTOM.value:
return await self._get_custom_tool_methods(config)
elif config.tool_type == ToolType.MCP.value:
return await self._get_mcp_tool_methods(config)
else:
return []
except Exception as e:
logger.error(f"获取工具方法失败: {tool_id}, {e}")
return []
async def _get_builtin_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
"""获取内置工具的方法"""
builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id)
if not builtin_config or builtin_config.tool_class not in BUILTIN_TOOLS:
return []
# 获取工具实例
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
if not tool_instance:
return []
# 检查是否有operation参数
operation_param = None
for param in tool_instance.parameters:
if param.name == "operation" and param.enum:
operation_param = param
break
if operation_param:
# 有多个操作
methods = []
for operation in operation_param.enum:
methods.append({
"method_id": f"{config.name}_{operation}",
"name": operation,
"description": f"{config.description} - {operation}",
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
})
return methods
else:
# 只有一个方法
return [{
"method_id": config.name,
"name": config.name,
"description": config.description,
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
}]
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
"""获取自定义工具的方法"""
custom_config = self.custom_repo.find_by_tool_id(self.db, config.id)
if not custom_config:
return []
try:
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
parser = OpenAPISchemaParser()
# 解析schema
if custom_config.schema_content:
success, schema, error = parser.parse_from_content(custom_config.schema_content, "application/json")
elif custom_config.schema_url:
success, schema, error = await parser.parse_from_url(custom_config.schema_url)
else:
return []
if not success:
return []
# 提取操作
tool_info = parser.extract_tool_info(schema)
operations = tool_info.get("operations", {})
methods = []
for operation_id, operation in operations.items():
# 生成参数列表
parameters = []
# 路径和查询参数
for param_name, param_info in operation.get("parameters", {}).items():
parameters.append({
"name": param_name,
"type": param_info.get("type", "string"),
"description": param_info.get("description", ""),
"required": param_info.get("required", False),
"enum": param_info.get("enum"),
"default": param_info.get("default")
})
# 请求体参数
request_body = operation.get("request_body")
if request_body:
schema_props = request_body.get("schema", {}).get("properties", {})
required_props = request_body.get("schema", {}).get("required", [])
for prop_name, prop_schema in schema_props.items():
parameters.append({
"name": prop_name,
"type": prop_schema.get("type", "string"),
"description": prop_schema.get("description", ""),
"required": prop_name in required_props,
"enum": prop_schema.get("enum"),
"default": prop_schema.get("default")
})
methods.append({
"method_id": operation_id,
"name": operation.get("summary", operation_id),
"description": operation.get("description", ""),
"method": operation.get("method", "GET"),
"path": operation.get("path", "/"),
"parameters": parameters
})
return methods
except Exception as e:
logger.error(f"解析自定义工具schema失败: {e}")
return []
async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
"""获取MCP工具的方法"""
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if not mcp_config:
return []
available_tools = mcp_config.available_tools or []
if not available_tools:
return []
methods = []
for tool_name in available_tools:
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": f"MCP工具: {tool_name}",
"parameters": [] # MCP工具参数需要动态获取
})
return methods
def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]:
"""获取工具统计信息"""
try:

View File

@@ -9,7 +9,8 @@ from typing import Dict, Any, Optional
from datetime import datetime
from app.models import AppRelease
from app.models.agent_app_config_model import AgentConfig
from app.models.multi_agent_model import MultiAgentConfig
class AgentConfigProxy:
"""Proxy class for AgentConfig (legacy compatibility)"""
@@ -24,20 +25,10 @@ class AgentConfigProxy:
self.default_model_config_id = release.default_model_config_id
def agent_config_4_app_release(release: AppRelease ):
from app.models.agent_app_config_model import AgentConfig
# Create AgentConfig instance
# config = {
# "system_prompt": agent_cfg.system_prompt,
# "model_parameters": agent_cfg.model_parameters,
# "knowledge_retrieval": agent_cfg.knowledge_retrieval,
# "memory": agent_cfg.memory,
# "variables": agent_cfg.variables or [],
# "tools": agent_cfg.tools or {},
# }
#
config_dict = release.config
def agent_config_4_app_release(release: AppRelease ) -> AgentConfig:
config_dict = release.config
agent_config = AgentConfig(
app_id=release.app_id,
system_prompt=config_dict.get("system_prompt"),
@@ -51,6 +42,26 @@ def agent_config_4_app_release(release: AppRelease ):
return agent_config
def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig:
config_dict = release.config
agent_config = MultiAgentConfig(
app_id=release.app_id,
default_model_config_id=release.default_model_config_id,
model_parameters=config_dict.get("model_parameters"),
master_agent_id=config_dict.get("master_agent_id"),
master_agent_name=config_dict.get("master_agent_name"),
orchestration_mode=config_dict.get("orchestration_mode", "conditional"),
sub_agents=config_dict.get("sub_agents", []),
routing_rules=config_dict.get("routing_rules"),
execution_config=config_dict.get("execution_config", {}),
aggregation_strategy=config_dict.get("aggregation_strategy", "merge"),
)
return agent_config
def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uuid.UUID] = None):
"""Convert dict to MultiAgentConfig model object