feat(DraftRun): support multimodal input for model comparison (#353)
This commit is contained in:
@@ -802,7 +802,8 @@ async def draft_run_compare(
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
timeout=payload.timeout or 60,
|
||||
files=payload.files
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
@@ -488,7 +488,7 @@ class DraftRunCompareRequest(BaseModel):
|
||||
max_length=5,
|
||||
description="要对比的模型列表(1-5个)"
|
||||
)
|
||||
|
||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||
parallel: bool = Field(True, description="是否并行执行")
|
||||
stream: bool = Field(False, description="是否流式返回")
|
||||
timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)")
|
||||
|
||||
@@ -16,26 +16,26 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
|
||||
from app.services.tool_service import ToolService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
"""知识库检索工具输入参数"""
|
||||
query: str = Field(description="需要检索的问题或关键词")
|
||||
@@ -48,9 +48,12 @@ class WebSearchInput(BaseModel):
|
||||
|
||||
class LongTermMemoryInput(BaseModel):
|
||||
"""长期记忆工具输入参数"""
|
||||
question: str = Field(description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
|
||||
question: str = Field(
|
||||
description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
|
||||
|
||||
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,user_rag_memory_id: Optional[str] = None):
|
||||
|
||||
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None):
|
||||
"""创建记忆工具,
|
||||
|
||||
|
||||
@@ -66,6 +69,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None)
|
||||
logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
|
||||
|
||||
@tool(args_schema=LongTermMemoryInput)
|
||||
def long_term_memory(question: str) -> str:
|
||||
"""
|
||||
@@ -133,6 +137,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
except Exception as e:
|
||||
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
||||
return f"记忆检索失败: {str(e)}"
|
||||
|
||||
return long_term_memory
|
||||
|
||||
|
||||
@@ -189,6 +194,7 @@ def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id):
|
||||
检索到的相关知识内容
|
||||
"""
|
||||
logger.info(f"创建知识库检索工具,用户:{user_id}")
|
||||
|
||||
@tool(args_schema=KnowledgeRetrievalInput)
|
||||
def knowledge_retrieval_tool(query: str) -> str:
|
||||
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
||||
@@ -200,7 +206,6 @@ def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id):
|
||||
检索到的相关知识内容
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
|
||||
retrieve_chunks_result = knowledge_retrieval(query, kb_config)
|
||||
@@ -226,6 +231,7 @@ def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id):
|
||||
|
||||
return knowledge_retrieval_tool
|
||||
|
||||
|
||||
class DraftRunService:
|
||||
"""试运行服务类"""
|
||||
|
||||
@@ -296,7 +302,6 @@ class DraftRunService:
|
||||
agent_config=agent_config
|
||||
)
|
||||
|
||||
|
||||
items_params = variables
|
||||
system_prompt = render_prompt_message(
|
||||
agent_config.system_prompt, # 修正拼写错误
|
||||
@@ -358,7 +363,8 @@ class DraftRunService:
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
@@ -391,7 +397,8 @@ class DraftRunService:
|
||||
memory_config = agent_config.memory
|
||||
if user_id:
|
||||
# 创建长期记忆工具
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id,storage_type,user_rag_memory_id)
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
logger.debug(
|
||||
@@ -626,14 +633,14 @@ class DraftRunService:
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
kb_config = agent_config.knowledge_retrieval
|
||||
@@ -658,7 +665,8 @@ class DraftRunService:
|
||||
memory_config = agent_config.memory
|
||||
if user_id:
|
||||
# 创建长期记忆工具
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id,storage_type,user_rag_memory_id)
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
logger.debug(
|
||||
@@ -863,7 +871,6 @@ class DraftRunService:
|
||||
BusinessException: 当指定的会话不存在时
|
||||
"""
|
||||
from app.models import Conversation as ConversationModel
|
||||
from app.schemas.conversation_schema import ConversationCreate
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
conversation_service = ConversationService(self.db)
|
||||
@@ -1157,6 +1164,7 @@ class DraftRunService:
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
web_search: bool = True,
|
||||
memory: bool = True,
|
||||
files: list[FileInput] | None = None
|
||||
) -> Dict[str, Any]:
|
||||
"""多模型对比试运行
|
||||
|
||||
@@ -1206,7 +1214,8 @@ class DraftRunService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=web_search,
|
||||
memory=memory
|
||||
memory=memory,
|
||||
files=files
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
@@ -1363,7 +1372,8 @@ class DraftRunService:
|
||||
web_search: bool = True,
|
||||
memory: bool = True,
|
||||
parallel: bool = True,
|
||||
timeout: int = 60
|
||||
timeout: int = 60,
|
||||
files: list[FileInput] | None = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""多模型对比试运行(流式返回)
|
||||
|
||||
@@ -1383,6 +1393,7 @@ class DraftRunService:
|
||||
memory: 是否启用记忆
|
||||
parallel: 是否并行执行
|
||||
timeout: 超时时间(秒)
|
||||
files: 多模态文件
|
||||
|
||||
Yields:
|
||||
str: SSE 格式的事件数据
|
||||
@@ -1441,7 +1452,8 @@ class DraftRunService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=web_search,
|
||||
memory=memory
|
||||
memory=memory,
|
||||
files=files
|
||||
):
|
||||
# 解析原始事件
|
||||
try:
|
||||
@@ -1696,4 +1708,3 @@ async def draft_run(
|
||||
similarity_threshold=similarity_threshold,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user