Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
This commit is contained in:
@@ -802,7 +802,8 @@ async def draft_run_compare(
|
|||||||
web_search=True,
|
web_search=True,
|
||||||
memory=True,
|
memory=True,
|
||||||
parallel=payload.parallel,
|
parallel=payload.parallel,
|
||||||
timeout=payload.timeout or 60
|
timeout=payload.timeout or 60,
|
||||||
|
files=payload.files
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -906,14 +907,14 @@ def get_app_statistics(
|
|||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
stats_service = AppStatisticsService(db)
|
stats_service = AppStatisticsService(db)
|
||||||
|
|
||||||
result = stats_service.get_app_statistics(
|
result = stats_service.get_app_statistics(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date
|
end_date=end_date
|
||||||
)
|
)
|
||||||
|
|
||||||
return success(data=result)
|
return success(data=result)
|
||||||
|
|
||||||
|
|
||||||
@@ -940,11 +941,11 @@ def get_workspace_api_statistics(
|
|||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
stats_service = AppStatisticsService(db)
|
stats_service = AppStatisticsService(db)
|
||||||
|
|
||||||
result = stats_service.get_workspace_api_statistics(
|
result = stats_service.get_workspace_api_statistics(
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date
|
end_date=end_date
|
||||||
)
|
)
|
||||||
|
|
||||||
return success(data=result)
|
return success(data=result)
|
||||||
|
|||||||
@@ -104,14 +104,18 @@ async def start_workspace_reflection(
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""启动工作空间中所有匹配应用的反思功能"""
|
"""启动工作空间中所有匹配应用的反思功能"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
reflection_service = MemoryReflectionService(db)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||||
|
|
||||||
service = WorkspaceAppService(db)
|
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
||||||
result = service.get_workspace_apps_detailed(workspace_id)
|
from app.db import get_db_context
|
||||||
|
with get_db_context() as query_db:
|
||||||
|
service = WorkspaceAppService(query_db)
|
||||||
|
result = service.get_workspace_apps_detailed(workspace_id)
|
||||||
|
|
||||||
reflection_results = []
|
reflection_results = []
|
||||||
|
|
||||||
for data in result['apps_detailed_info']:
|
for data in result['apps_detailed_info']:
|
||||||
# 跳过没有配置的应用
|
# 跳过没有配置的应用
|
||||||
if not data['memory_configs']:
|
if not data['memory_configs']:
|
||||||
@@ -133,33 +137,36 @@ async def start_workspace_reflection(
|
|||||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 为每个用户执行反思
|
# 为每个用户执行反思 - 使用独立的数据库会话
|
||||||
for user in end_users:
|
for user in end_users:
|
||||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||||
|
|
||||||
try:
|
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
||||||
reflection_result = await reflection_service.start_text_reflection(
|
with get_db_context() as user_db:
|
||||||
config_data=config,
|
try:
|
||||||
end_user_id=user['id']
|
reflection_service = MemoryReflectionService(user_db)
|
||||||
)
|
reflection_result = await reflection_service.start_text_reflection(
|
||||||
|
config_data=config,
|
||||||
|
end_user_id=user['id']
|
||||||
|
)
|
||||||
|
|
||||||
reflection_results.append({
|
reflection_results.append({
|
||||||
"app_id": data['id'],
|
"app_id": data['id'],
|
||||||
"config_id": config_id_str,
|
"config_id": config_id_str,
|
||||||
"end_user_id": user['id'],
|
"end_user_id": user['id'],
|
||||||
"reflection_result": reflection_result
|
"reflection_result": reflection_result
|
||||||
})
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
|
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
|
||||||
reflection_results.append({
|
reflection_results.append({
|
||||||
"app_id": data['id'],
|
"app_id": data['id'],
|
||||||
"config_id": config_id_str,
|
"config_id": config_id_str,
|
||||||
"end_user_id": user['id'],
|
"end_user_id": user['id'],
|
||||||
"reflection_result": {
|
"reflection_result": {
|
||||||
"status": "错误",
|
"status": "错误",
|
||||||
"message": f"反思失败: {str(e)}"
|
"message": f"反思失败: {str(e)}"
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
return success(data=reflection_results, msg="反思配置成功")
|
return success(data=reflection_results, msg="反思配置成功")
|
||||||
|
|
||||||
|
|||||||
@@ -462,8 +462,8 @@ class ReflectionEngine:
|
|||||||
List[Any]: 反思数据列表
|
List[Any]: 反思数据列表
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
print("=== 获取反思数据 ===")
|
||||||
|
print(f" 主机ID: {host_id}")
|
||||||
if self.config.reflexion_range == ReflectionRange.PARTIAL:
|
if self.config.reflexion_range == ReflectionRange.PARTIAL:
|
||||||
neo4j_query = neo4j_query_part.format(host_id)
|
neo4j_query = neo4j_query_part.format(host_id)
|
||||||
neo4j_statement = neo4j_statement_part.format(host_id)
|
neo4j_statement = neo4j_statement_part.format(host_id)
|
||||||
|
|||||||
@@ -488,7 +488,7 @@ class DraftRunCompareRequest(BaseModel):
|
|||||||
max_length=5,
|
max_length=5,
|
||||||
description="要对比的模型列表(1-5个)"
|
description="要对比的模型列表(1-5个)"
|
||||||
)
|
)
|
||||||
|
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||||
parallel: bool = Field(True, description="是否并行执行")
|
parallel: bool = Field(True, description="是否并行执行")
|
||||||
stream: bool = Field(False, description="是否流式返回")
|
stream: bool = Field(False, description="是否流式返回")
|
||||||
timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)")
|
timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)")
|
||||||
|
|||||||
@@ -16,26 +16,26 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
|
from app.core.agent.agent_middleware import AgentMiddleware
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
from app.models import AgentConfig, ModelConfig
|
||||||
from app.repositories.model_repository import ModelApiKeyRepository
|
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
|
||||||
from app.schemas.app_schema import FileInput
|
from app.schemas.app_schema import FileInput
|
||||||
|
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||||
from app.services import task_service
|
from app.services import task_service
|
||||||
from app.services.langchain_tool_server import Search
|
from app.services.langchain_tool_server import Search
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.model_parameter_merger import ModelParameterMerger
|
from app.services.model_parameter_merger import ModelParameterMerger
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.tool_service import ToolService
|
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
from app.core.agent.agent_middleware import AgentMiddleware
|
from app.services.tool_service import ToolService
|
||||||
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalInput(BaseModel):
|
class KnowledgeRetrievalInput(BaseModel):
|
||||||
"""知识库检索工具输入参数"""
|
"""知识库检索工具输入参数"""
|
||||||
query: str = Field(description="需要检索的问题或关键词")
|
query: str = Field(description="需要检索的问题或关键词")
|
||||||
@@ -48,9 +48,12 @@ class WebSearchInput(BaseModel):
|
|||||||
|
|
||||||
class LongTermMemoryInput(BaseModel):
|
class LongTermMemoryInput(BaseModel):
|
||||||
"""长期记忆工具输入参数"""
|
"""长期记忆工具输入参数"""
|
||||||
question: str = Field(description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
|
question: str = Field(
|
||||||
|
description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
|
||||||
|
|
||||||
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,user_rag_memory_id: Optional[str] = None):
|
|
||||||
|
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,
|
||||||
|
user_rag_memory_id: Optional[str] = None):
|
||||||
"""创建记忆工具,
|
"""创建记忆工具,
|
||||||
|
|
||||||
|
|
||||||
@@ -66,6 +69,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||||
config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None)
|
config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None)
|
||||||
logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
|
logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
|
||||||
|
|
||||||
@tool(args_schema=LongTermMemoryInput)
|
@tool(args_schema=LongTermMemoryInput)
|
||||||
def long_term_memory(question: str) -> str:
|
def long_term_memory(question: str) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -133,6 +137,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
||||||
return f"记忆检索失败: {str(e)}"
|
return f"记忆检索失败: {str(e)}"
|
||||||
|
|
||||||
return long_term_memory
|
return long_term_memory
|
||||||
|
|
||||||
|
|
||||||
@@ -179,7 +184,7 @@ def create_web_search_tool(web_search_config: Dict[str, Any]):
|
|||||||
return web_search_tool
|
return web_search_tool
|
||||||
|
|
||||||
|
|
||||||
def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id):
|
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
||||||
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -189,6 +194,7 @@ def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id):
|
|||||||
检索到的相关知识内容
|
检索到的相关知识内容
|
||||||
"""
|
"""
|
||||||
logger.info(f"创建知识库检索工具,用户:{user_id}")
|
logger.info(f"创建知识库检索工具,用户:{user_id}")
|
||||||
|
|
||||||
@tool(args_schema=KnowledgeRetrievalInput)
|
@tool(args_schema=KnowledgeRetrievalInput)
|
||||||
def knowledge_retrieval_tool(query: str) -> str:
|
def knowledge_retrieval_tool(query: str) -> str:
|
||||||
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
||||||
@@ -200,7 +206,6 @@ def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id):
|
|||||||
检索到的相关知识内容
|
检索到的相关知识内容
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
retrieve_chunks_result = knowledge_retrieval(query, kb_config)
|
retrieve_chunks_result = knowledge_retrieval(query, kb_config)
|
||||||
@@ -226,6 +231,7 @@ def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id):
|
|||||||
|
|
||||||
return knowledge_retrieval_tool
|
return knowledge_retrieval_tool
|
||||||
|
|
||||||
|
|
||||||
class DraftRunService:
|
class DraftRunService:
|
||||||
"""试运行服务类"""
|
"""试运行服务类"""
|
||||||
|
|
||||||
@@ -238,21 +244,21 @@ class DraftRunService:
|
|||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
message: str,
|
message: str,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
web_search: bool = True,
|
web_search: bool = True,
|
||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
sub_agent: bool = False,
|
sub_agent: bool = False,
|
||||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""执行试运行(使用 LangChain Agent)
|
"""执行试运行(使用 LangChain Agent)
|
||||||
|
|
||||||
@@ -268,9 +274,9 @@ class DraftRunService:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict: 包含 AI 回复和元数据的字典
|
Dict: 包含 AI 回复和元数据的字典
|
||||||
"""
|
"""
|
||||||
memory_flag=False
|
memory_flag = False
|
||||||
|
|
||||||
print('===========',storage_type)
|
print('===========', storage_type)
|
||||||
|
|
||||||
print(user_id)
|
print(user_id)
|
||||||
if variables == None: variables = {}
|
if variables == None: variables = {}
|
||||||
@@ -296,8 +302,7 @@ class DraftRunService:
|
|||||||
agent_config=agent_config
|
agent_config=agent_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
items_params = variables
|
||||||
items_params=variables
|
|
||||||
system_prompt = render_prompt_message(
|
system_prompt = render_prompt_message(
|
||||||
agent_config.system_prompt, # 修正拼写错误
|
agent_config.system_prompt, # 修正拼写错误
|
||||||
PromptMessageRole.USER,
|
PromptMessageRole.USER,
|
||||||
@@ -306,7 +311,7 @@ class DraftRunService:
|
|||||||
|
|
||||||
# 3. 处理系统提示词(支持变量替换)
|
# 3. 处理系统提示词(支持变量替换)
|
||||||
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
||||||
print('系统提示词:',system_prompt)
|
print('系统提示词:', system_prompt)
|
||||||
|
|
||||||
# 4. 准备工具列表
|
# 4. 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
@@ -318,7 +323,7 @@ class DraftRunService:
|
|||||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||||
for tool_config in agent_config.tools:
|
for tool_config in agent_config.tools:
|
||||||
print("+"*50)
|
print("+" * 50)
|
||||||
print(f"agent_config:{agent_config}")
|
print(f"agent_config:{agent_config}")
|
||||||
print(f"tool_config:{tool_config}")
|
print(f"tool_config:{tool_config}")
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
@@ -358,7 +363,8 @@ class DraftRunService:
|
|||||||
|
|
||||||
# 应用动态过滤
|
# 应用动态过滤
|
||||||
if skill_configs:
|
if skill_configs:
|
||||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||||
|
tool_to_skill_map)
|
||||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||||
active_prompts = AgentMiddleware.get_active_prompts(
|
active_prompts = AgentMiddleware.get_active_prompts(
|
||||||
activated_skill_ids, skill_configs
|
activated_skill_ids, skill_configs
|
||||||
@@ -372,7 +378,7 @@ class DraftRunService:
|
|||||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
||||||
if kb_ids:
|
if kb_ids:
|
||||||
# 创建知识库检索工具
|
# 创建知识库检索工具
|
||||||
kb_tool = create_knowledge_retrieval_tool(kb_config,kb_ids,user_id)
|
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
|
||||||
tools.append(kb_tool)
|
tools.append(kb_tool)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -386,12 +392,13 @@ class DraftRunService:
|
|||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
if memory:
|
if memory:
|
||||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||||
memory_flag=True
|
memory_flag = True
|
||||||
|
|
||||||
memory_config = agent_config.memory
|
memory_config = agent_config.memory
|
||||||
if user_id:
|
if user_id:
|
||||||
# 创建长期记忆工具
|
# 创建长期记忆工具
|
||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id,storage_type,user_rag_memory_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||||
|
user_rag_memory_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -452,7 +459,7 @@ class DraftRunService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_config_= agent_config.memory
|
memory_config_ = agent_config.memory
|
||||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||||
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
|
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
|
||||||
|
|
||||||
@@ -518,21 +525,21 @@ class DraftRunService:
|
|||||||
raise BusinessException(f"Agent 调用失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
|
raise BusinessException(f"Agent 调用失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
|
||||||
|
|
||||||
async def run_stream(
|
async def run_stream(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
message: str,
|
message: str,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
web_search: bool = True, # 布尔类型默认值
|
web_search: bool = True, # 布尔类型默认值
|
||||||
memory: bool = True, # 布尔类型默认值
|
memory: bool = True, # 布尔类型默认值
|
||||||
sub_agent: bool = False, # 是否是作为子Agent运行
|
sub_agent: bool = False, # 是否是作为子Agent运行
|
||||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
|
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""执行试运行(流式返回,使用 LangChain Agent)
|
"""执行试运行(流式返回,使用 LangChain Agent)
|
||||||
@@ -549,8 +556,8 @@ class DraftRunService:
|
|||||||
Yields:
|
Yields:
|
||||||
str: SSE 格式的事件数据
|
str: SSE 格式的事件数据
|
||||||
"""
|
"""
|
||||||
memory_flag=False
|
memory_flag = False
|
||||||
if variables==None:variables={}
|
if variables == None: variables = {}
|
||||||
|
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
|
|
||||||
@@ -566,7 +573,7 @@ class DraftRunService:
|
|||||||
agent_config=agent_config
|
agent_config=agent_config
|
||||||
)
|
)
|
||||||
|
|
||||||
items_params=variables
|
items_params = variables
|
||||||
|
|
||||||
system_prompt = render_prompt_message(
|
system_prompt = render_prompt_message(
|
||||||
agent_config.system_prompt, # 修正拼写错误
|
agent_config.system_prompt, # 修正拼写错误
|
||||||
@@ -626,14 +633,14 @@ class DraftRunService:
|
|||||||
|
|
||||||
# 应用动态过滤
|
# 应用动态过滤
|
||||||
if skill_configs:
|
if skill_configs:
|
||||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||||
|
tool_to_skill_map)
|
||||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||||
active_prompts = AgentMiddleware.get_active_prompts(
|
active_prompts = AgentMiddleware.get_active_prompts(
|
||||||
activated_skill_ids, skill_configs
|
activated_skill_ids, skill_configs
|
||||||
)
|
)
|
||||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||||
|
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
if agent_config.knowledge_retrieval:
|
if agent_config.knowledge_retrieval:
|
||||||
kb_config = agent_config.knowledge_retrieval
|
kb_config = agent_config.knowledge_retrieval
|
||||||
@@ -654,11 +661,12 @@ class DraftRunService:
|
|||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
if memory:
|
if memory:
|
||||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||||
memory_flag= True
|
memory_flag = True
|
||||||
memory_config = agent_config.memory
|
memory_config = agent_config.memory
|
||||||
if user_id:
|
if user_id:
|
||||||
# 创建长期记忆工具
|
# 创建长期记忆工具
|
||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id,storage_type,user_rag_memory_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||||
|
user_rag_memory_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -724,15 +732,15 @@ class DraftRunService:
|
|||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=context,
|
context=context,
|
||||||
end_user_id=user_id,
|
end_user_id=user_id,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
memory_flag=memory_flag,
|
memory_flag=memory_flag,
|
||||||
files=processed_files # 传递处理后的文件
|
files=processed_files # 传递处理后的文件
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
total_tokens = chunk
|
total_tokens = chunk
|
||||||
@@ -749,8 +757,8 @@ class DraftRunService:
|
|||||||
|
|
||||||
if sub_agent:
|
if sub_agent:
|
||||||
yield self._format_sse_event("sub_usage", {
|
yield self._format_sse_event("sub_usage", {
|
||||||
"total_tokens": total_tokens
|
"total_tokens": total_tokens
|
||||||
})
|
})
|
||||||
|
|
||||||
# 10. 保存会话消息
|
# 10. 保存会话消息
|
||||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||||
@@ -842,11 +850,11 @@ class DraftRunService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def _ensure_conversation(
|
async def _ensure_conversation(
|
||||||
self,
|
self,
|
||||||
conversation_id: Optional[str],
|
conversation_id: Optional[str],
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
user_id: Optional[str]
|
user_id: Optional[str]
|
||||||
) -> str:
|
) -> str:
|
||||||
"""确保会话存在(创建或验证)
|
"""确保会话存在(创建或验证)
|
||||||
|
|
||||||
@@ -863,7 +871,6 @@ class DraftRunService:
|
|||||||
BusinessException: 当指定的会话不存在时
|
BusinessException: 当指定的会话不存在时
|
||||||
"""
|
"""
|
||||||
from app.models import Conversation as ConversationModel
|
from app.models import Conversation as ConversationModel
|
||||||
from app.schemas.conversation_schema import ConversationCreate
|
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
|
|
||||||
conversation_service = ConversationService(self.db)
|
conversation_service = ConversationService(self.db)
|
||||||
@@ -945,9 +952,9 @@ class DraftRunService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _load_conversation_history(
|
async def _load_conversation_history(
|
||||||
self,
|
self,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
max_history: int = 10
|
max_history: int = 10
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
"""加载会话历史消息
|
"""加载会话历史消息
|
||||||
|
|
||||||
@@ -984,13 +991,13 @@ class DraftRunService:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def _save_conversation_message(
|
async def _save_conversation_message(
|
||||||
self,
|
self,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
user_message: str,
|
user_message: str,
|
||||||
assistant_message: str,
|
assistant_message: str,
|
||||||
meta_data: dict,
|
meta_data: dict,
|
||||||
app_id: Optional[uuid.UUID] = None,
|
app_id: Optional[uuid.UUID] = None,
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)
|
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)
|
||||||
|
|
||||||
@@ -1100,10 +1107,10 @@ class DraftRunService:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _replace_variables(
|
def _replace_variables(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
values: Dict[str, Any],
|
values: Dict[str, Any],
|
||||||
definitions: List[Dict[str, Any]]
|
definitions: List[Dict[str, Any]]
|
||||||
) -> str:
|
) -> str:
|
||||||
"""替换文本中的变量
|
"""替换文本中的变量
|
||||||
|
|
||||||
@@ -1129,8 +1136,8 @@ class DraftRunService:
|
|||||||
# 替换变量(支持多种格式)
|
# 替换变量(支持多种格式)
|
||||||
placeholders = [
|
placeholders = [
|
||||||
f"{{{{{var_name}}}}}", # {{var_name}}
|
f"{{{{{var_name}}}}}", # {{var_name}}
|
||||||
f"{{{var_name}}}", # {var_name}
|
f"{{{var_name}}}", # {var_name}
|
||||||
f"${{{var_name}}}", # ${var_name}
|
f"${{{var_name}}}", # ${var_name}
|
||||||
]
|
]
|
||||||
|
|
||||||
for placeholder in placeholders:
|
for placeholder in placeholders:
|
||||||
@@ -1142,21 +1149,22 @@ class DraftRunService:
|
|||||||
# ==================== 多模型对比试运行 ====================
|
# ==================== 多模型对比试运行 ====================
|
||||||
|
|
||||||
async def run_compare(
|
async def run_compare(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
models: List[Dict[str, Any]],
|
models: List[Dict[str, Any]],
|
||||||
message: str,
|
message: str,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
parallel: bool = True,
|
parallel: bool = True,
|
||||||
timeout: int = 60,
|
timeout: int = 60,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
web_search: bool = True,
|
web_search: bool = True,
|
||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
|
files: list[FileInput] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""多模型对比试运行
|
"""多模型对比试运行
|
||||||
|
|
||||||
@@ -1206,7 +1214,8 @@ class DraftRunService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
web_search=web_search,
|
web_search=web_search,
|
||||||
memory=memory
|
memory=memory,
|
||||||
|
files=files
|
||||||
),
|
),
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
)
|
)
|
||||||
@@ -1221,7 +1230,7 @@ class DraftRunService:
|
|||||||
"model_config_id": model_info["model_config_id"],
|
"model_config_id": model_info["model_config_id"],
|
||||||
"model_name": model_info["model_config"].name,
|
"model_name": model_info["model_config"].name,
|
||||||
"label": model_info["label"],
|
"label": model_info["label"],
|
||||||
"conversation_id":result['conversation_id'],
|
"conversation_id": result['conversation_id'],
|
||||||
"parameters_used": model_info["parameters"],
|
"parameters_used": model_info["parameters"],
|
||||||
"message": result.get("message"),
|
"message": result.get("message"),
|
||||||
"usage": usage,
|
"usage": usage,
|
||||||
@@ -1349,21 +1358,22 @@ class DraftRunService:
|
|||||||
return agent_config, original_params
|
return agent_config, original_params
|
||||||
|
|
||||||
async def run_compare_stream(
|
async def run_compare_stream(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
models: List[Dict[str, Any]],
|
models: List[Dict[str, Any]],
|
||||||
message: str,
|
message: str,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
web_search: bool = True,
|
web_search: bool = True,
|
||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
parallel: bool = True,
|
parallel: bool = True,
|
||||||
timeout: int = 60
|
timeout: int = 60,
|
||||||
|
files: list[FileInput] | None = None
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""多模型对比试运行(流式返回)
|
"""多模型对比试运行(流式返回)
|
||||||
|
|
||||||
@@ -1383,6 +1393,7 @@ class DraftRunService:
|
|||||||
memory: 是否启用记忆
|
memory: 是否启用记忆
|
||||||
parallel: 是否并行执行
|
parallel: 是否并行执行
|
||||||
timeout: 超时时间(秒)
|
timeout: 超时时间(秒)
|
||||||
|
files: 多模态文件
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
str: SSE 格式的事件数据
|
str: SSE 格式的事件数据
|
||||||
@@ -1431,17 +1442,18 @@ class DraftRunService:
|
|||||||
try:
|
try:
|
||||||
# 流式调用单个模型
|
# 流式调用单个模型
|
||||||
async for event_str in self.run_stream(
|
async for event_str in self.run_stream(
|
||||||
agent_config=agent_config,
|
agent_config=agent_config,
|
||||||
model_config=model_info["model_config"],
|
model_config=model_info["model_config"],
|
||||||
message=message,
|
message=message,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
conversation_id=model_conversation_id,
|
conversation_id=model_conversation_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
variables=variables,
|
variables=variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
web_search=web_search,
|
web_search=web_search,
|
||||||
memory=memory
|
memory=memory,
|
||||||
|
files=files
|
||||||
):
|
):
|
||||||
# 解析原始事件
|
# 解析原始事件
|
||||||
try:
|
try:
|
||||||
@@ -1661,15 +1673,15 @@ class DraftRunService:
|
|||||||
|
|
||||||
|
|
||||||
async def draft_run(
|
async def draft_run(
|
||||||
db: Session,
|
db: Session,
|
||||||
*,
|
*,
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
message: str,
|
message: str,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
kb_ids: Optional[List[str]] = None,
|
kb_ids: Optional[List[str]] = None,
|
||||||
similarity_threshold: float = 0.7,
|
similarity_threshold: float = 0.7,
|
||||||
top_k: int = 3
|
top_k: int = 3
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""试运行 Agent(便捷函数)
|
"""试运行 Agent(便捷函数)
|
||||||
|
|
||||||
@@ -1696,4 +1708,3 @@ async def draft_run(
|
|||||||
similarity_threshold=similarity_threshold,
|
similarity_threshold=similarity_threshold,
|
||||||
top_k=top_k
|
top_k=top_k
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -313,13 +313,16 @@ class MemoryAgentService:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Load configuration from database with workspace fallback
|
# Load configuration from database with workspace fallback
|
||||||
|
# Use a separate database session to avoid transaction failures
|
||||||
try:
|
try:
|
||||||
config_service = MemoryConfigService(db)
|
from app.db import get_db_context
|
||||||
memory_config = config_service.load_memory_config(
|
with get_db_context() as config_db:
|
||||||
config_id=config_id,
|
config_service = MemoryConfigService(config_db)
|
||||||
workspace_id=workspace_id,
|
memory_config = config_service.load_memory_config(
|
||||||
service_name="MemoryAgentService"
|
config_id=config_id,
|
||||||
)
|
workspace_id=workspace_id,
|
||||||
|
service_name="MemoryAgentService"
|
||||||
|
)
|
||||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||||
except ConfigurationError as e:
|
except ConfigurationError as e:
|
||||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||||
@@ -454,12 +457,15 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
config_load_start = time.time()
|
config_load_start = time.time()
|
||||||
try:
|
try:
|
||||||
config_service = MemoryConfigService(db)
|
# Use a separate database session to avoid transaction failures
|
||||||
memory_config = config_service.load_memory_config(
|
from app.db import get_db_context
|
||||||
config_id=config_id,
|
with get_db_context() as config_db:
|
||||||
workspace_id=workspace_id,
|
config_service = MemoryConfigService(config_db)
|
||||||
service_name="MemoryAgentService"
|
memory_config = config_service.load_memory_config(
|
||||||
)
|
config_id=config_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
service_name="MemoryAgentService"
|
||||||
|
)
|
||||||
config_load_time = time.time() - config_load_start
|
config_load_time = time.time() - config_load_start
|
||||||
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
|
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
|
||||||
except ConfigurationError as e:
|
except ConfigurationError as e:
|
||||||
|
|||||||
@@ -364,7 +364,8 @@ class MemoryReflectionService:
|
|||||||
reflexion_range_value = config_data.get("reflexion_range")
|
reflexion_range_value = config_data.get("reflexion_range")
|
||||||
if reflexion_range_value is None or reflexion_range_value == "":
|
if reflexion_range_value is None or reflexion_range_value == "":
|
||||||
reflexion_range_value = "partial"
|
reflexion_range_value = "partial"
|
||||||
# Map legacy/invalid values to valid enum values
|
|
||||||
|
# Map legacy/invalid values to valid enum values
|
||||||
reflexion_range_mapping = {
|
reflexion_range_mapping = {
|
||||||
"retrieval": "partial", # Map old 'retrieval' to 'partial'
|
"retrieval": "partial", # Map old 'retrieval' to 'partial'
|
||||||
"partial": "partial",
|
"partial": "partial",
|
||||||
@@ -378,13 +379,19 @@ class MemoryReflectionService:
|
|||||||
baseline_value = "TIME"
|
baseline_value = "TIME"
|
||||||
baseline = ReflectionBaseline(baseline_value)
|
baseline = ReflectionBaseline(baseline_value)
|
||||||
|
|
||||||
# iteration_period =
|
# iteration_period
|
||||||
iteration_period = config_data.get("iteration_period", 24)
|
iteration_period = config_data.get("iteration_period", 24)
|
||||||
if isinstance(iteration_period, str):
|
if isinstance(iteration_period, str):
|
||||||
try:
|
try:
|
||||||
iteration_period = int(iteration_period)
|
iteration_period = int(iteration_period)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
iteration_period = 24 # 默认24小时
|
iteration_period = 24 # 默认24小时
|
||||||
|
|
||||||
|
# 获取 model_id 并转换为字符串(如果是 UUID 对象)
|
||||||
|
reflection_model_id = config_data.get("reflection_model_id", "")
|
||||||
|
if reflection_model_id:
|
||||||
|
reflection_model_id = str(reflection_model_id)
|
||||||
|
|
||||||
return ReflectionConfig(
|
return ReflectionConfig(
|
||||||
enabled=config_data.get("enable_self_reflexion", False),
|
enabled=config_data.get("enable_self_reflexion", False),
|
||||||
iteration_period=str(iteration_period), # ReflectionConfig期望字符串
|
iteration_period=str(iteration_period), # ReflectionConfig期望字符串
|
||||||
@@ -392,7 +399,7 @@ class MemoryReflectionService:
|
|||||||
baseline=baseline,
|
baseline=baseline,
|
||||||
memory_verify=config_data.get("memory_verify", False),
|
memory_verify=config_data.get("memory_verify", False),
|
||||||
quality_assessment=config_data.get("quality_assessment", False),
|
quality_assessment=config_data.get("quality_assessment", False),
|
||||||
model_id=config_data.get("reflection_model_id", "")
|
model_id=reflection_model_id
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _execute_reflection_engine(
|
async def _execute_reflection_engine(
|
||||||
|
|||||||
@@ -211,7 +211,8 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
|
|||||||
|
|
||||||
// Process parser_config data, set default values if not present
|
// Process parser_config data, set default values if not present
|
||||||
const recordAny = record as any;
|
const recordAny = record as any;
|
||||||
baseValues.parser_config = record.parser_config || {
|
baseValues.parser_config = {
|
||||||
|
...record.parser_config,
|
||||||
graphrag: {
|
graphrag: {
|
||||||
use_graphrag: false,
|
use_graphrag: false,
|
||||||
scene_name: '',
|
scene_name: '',
|
||||||
@@ -219,6 +220,7 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
|
|||||||
method: 'general',
|
method: 'general',
|
||||||
resolution: false,
|
resolution: false,
|
||||||
community: false,
|
community: false,
|
||||||
|
...(record.parser_config?.graphrag || {})
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -656,7 +658,7 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
|
|||||||
{currentType !== 'Folder' && dynamicTypeList.map((tp) => {
|
{currentType !== 'Folder' && dynamicTypeList.map((tp) => {
|
||||||
const fieldKey = typeToFieldKey(tp);
|
const fieldKey = typeToFieldKey(tp);
|
||||||
// When tp is 'llm', merge llm and chat options
|
// When tp is 'llm', merge llm and chat options
|
||||||
const options = tp.toLowerCase() === 'llm'
|
const options = tp.toLowerCase() === 'llm' || tp.toLowerCase() === 'image2text'
|
||||||
? [...(modelOptionsByType['llm'] || []), ...(modelOptionsByType['chat'] || [])]
|
? [...(modelOptionsByType['llm'] || []), ...(modelOptionsByType['chat'] || [])]
|
||||||
: modelOptionsByType[tp] || [];
|
: modelOptionsByType[tp] || [];
|
||||||
return (
|
return (
|
||||||
|
|||||||
Reference in New Issue
Block a user