* feat(web): add PageEmpty component
* feat(web): add PageTabs component
* feat(web): add PageEmpty component
* feat(web): add PageTabs component
* feat(prompt): add history tracking for prompt releases
* feat(web): add prompt menu
* refactor: The PageScrollList component supports two generic parameters
* feat(web): update PageLoading in the BodyWrapper component
* feat(web): add Ontology menu
* feat(web): memory management add scene
* feat(tasks): add celery task configuration for periodic jobs
- Add ignore_result=True to prevent storing results for periodic tasks
- Set max_retries=0 to skip failed periodic tasks without retry attempts
- Configure acks_late=False for immediate acknowledgment in beat tasks
- Add time_limit and soft_time_limit to regenerate_memory_cache task (3600s/3300s)
- Add time_limit and soft_time_limit to workspace_reflection_task (300s/240s)
- Add time_limit and soft_time_limit to run_forgetting_cycle_task (7200s/7000s)
- Improve task reliability and resource management for scheduled jobs
* feat(sandbox): add Node.js code execution support to sandbox
* Release/v0.2.2 (#260)
* [modify] migration script
* [add] migration script
* fix(web): change form message
* fix(web): the memoryContent field is compatible with numbers and strings
* feat(web): code node hidden
* fix(model):
1. When creating a basic model, check whether the name and provider are duplicated.
2. Fix error models appearing in the results because the provider created API keys for all matching models.
---------
Co-authored-by: Mark <zhuwenhui5566@163.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Timebomb2018 <18868801967@163.com>
* Feature/ontology class clean (#249)
* [add] Complete ontology engineering feature implementation
* [add] Add ontology feature integration and validation utilities
* [add] Add OWL validator and validation utilities
* [fix] Add missing render_ontology_extraction_prompt function
* [fix]Add dependencies, fix functionality
* [add] migration script
* feat(celery): add dedicated periodic tasks worker and queue (#261)
* fix(web): conflict resolve
* Fix/v022 bug (#263)
* [fix]Fix the issue of inconsistent language in explicit and episodic memory.
* [fix]Fix the issue of inconsistent language in explicit and episodic memory.
* [add]Add scene_id
* [fix]Based on the AI review to fix the code
* Fix/develop memory reflex (#265)
* Missing history mapping
* Missing history mapping
* Handle errors in the reflection background task
* [add] migration script
* fix: chat conversation_id add node_start
* feat(web): show code node
* fix(web): Restructure the CustomSelect component, repair the interface that is called multiple times when the form is updated
* feat(web): RadioGroupCard support block mode
* feat(web): create space add icon
* feat(app and model): token consumption statistics
* Add/develop memory (#264)
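* Missing history mapping
* Missing history mapping
* Missing history mapping
* Missing history mapping
* Missing history mapping
* Missing history mapping
* Missing history mapping
* Missing history mapping
* Missing history mapping
* Add long-term memory feature
* Add long-term memory feature
* Add long-term memory feature
* Redundant fields in knowledge-base retrieval
* Long-term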
* feat(app and model): token consumption statistics of the cluster
* memory_BUG_fix
* fix(web): prompt history remove pageLoading
* fix(prompt): remove hard-coded import of prompt file paths (#279)
* Fix/develop memory bug (#274)
* Missing history mapping
* Missing history mapping
* fix_timeline_memories
* fix(web): update retrieve_type key
* Fix/develop memory bug (#276)
* Missing history mapping
* Missing history mapping
* fix_timeline_memories
* fix_timeline_memories
* write_gragp/bug_fix
* write_gragp/bug_fix
* write_gragp/bug_fix
* chore(celery): disable periodic task scheduling
* fix(prompt): remove hard-coded import of prompt file paths
---------
Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Ke Sun <kesun5@illinois.edu>
* fix(web): remove delete confirm content
* refactor(workflow): relocate template directory into workflow
* feat(memory): add long-term storage task routing and batching
* fix(web): PageScrollList loading update
* fix(web): PageScrollList loading update
* Ontology v1 bug (#291)
* [changes]Add 'id' as the secondary sorting key, and 'scene_id' now returns a UUID object
* [fix]Fix the "end_user" return to be sorted by update time.
* [fix]Set the default values of the memory configuration model based on the spatial model.
* [fix]Remove the entity extraction check combination model, read the configuration list, and add the return of scene_id
* [fix]Fix the "end_user" return to be sorted by update time.
* [fix]
* fix(memory): add Redis session validation
- Add macOS fork() safety configuration in celery_app.py to prevent initialization issues
- Add null/False checks for Redis session queries in term_memory_save to handle missing sessions gracefully
- Add null/False checks in memory_long_term_storage to prevent processing empty Redis results
- Add null/False checks in aggregate_judgment before format_parsing to avoid errors on missing data
- Initialize redis_messages variable in window_dialogue for consistency
- Add debug logging when no existing session found in Redis for better troubleshooting
- Add TODO comments for magic numbers (scope=6, time=5) to be extracted as constants
- Improve error handling when Redis returns False or empty results instead of crashing
* fix(web): PageScrollList style update
* fix(workflow): fix argument passing in code execution nodes
* fix(web): prompt add disabled
* fix(web): space icon required
* feat(app): modify the key of the token
* fix(app): fix the key of the app's token
* fix(workflow): switch code input encoding to base64+URL encoding
* [add] Add multi-API-key load balancing to the main project.
* [changes] Safe attribute access, safe numeric conversion, consistent use of local variables
* fix(web): save add session update
* fix(web): language editor support paste
* [changes]Active status filtering logic, API Key selection strategy
* memory_BUG
* memory_BUG_long_term
* [changes]
* memory_BUG_long_term
* memory_BUG_long_term
* Fix/release memory bug (#306)
* memory_BUG_fix
* memory_BUG
* memory_BUG_long_term
* memory_BUG_long_term
* memory_BUG_long_term
* knowledge_retrieval/bug/fix
* knowledge_retrieval/bug/fix
* knowledge_retrieval/bug/fix
* [fix] 1. The "read_all_config" endpoint returns "scene_name"; 2. Use a lightweight query for ontology scenarios in the memory configuration
* fix(web): replace code editor
* [changes]Modify the description of the time for the recent event
* [changes]Modify the code based on the AI review
* feat(web): update memory config ontology api
* fix(web): ui update
* knowledge_retrieval/bug/fix
* knowledge_retrieval/bug/fix
* knowledge_retrieval/bug/fix
* feat(workflow): add token usage statistics for question classifier and parameter extraction
* feat(web): move prompt menu
* Multiple independent transactions - single transaction
* Multiple independent transactions - single transaction
* Multiple independent transactions - single transaction
* Multiple independent transactions - single transaction
* Write Missing None (#321)
* Write Missing None
* Write Missing None
* Write Missing None
* Apply suggestion from @sourcery-ai[bot]
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Write Missing None
---------
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Fix/release memory bug (#324)
* Write Missing None
* Write Missing None
* Write Missing None
* Apply suggestion from @sourcery-ai[bot]
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Write Missing None
* redis update
* redis update
* redis update
* redis update
---------
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Fix/writer memory bug (#326)
* [fix]Fix the bug
* [fix]Fix the bug
* [fix]Correct the direction indication.
* fix(web): markdown table ui update
* Fix/release memory bug (#332)
* Write Missing None
* Write Missing None
* Write Missing None
* Apply suggestion from @sourcery-ai[bot]
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Write Missing None
* redis update
* redis update
* redis update
* redis update
* writer_dup_bug/fix
---------
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Fix/fact summary (#333)
* [fix]Disable the contents related to fact_summary
* [fix]Disable the contents related to fact_summary
* [fix]Modify the code based on the AI review
* Fix/release memory bug (#335)
* Write Missing None
* Write Missing None
* Write Missing None
* Apply suggestion from @sourcery-ai[bot]
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Write Missing None
* redis update
* redis update
* redis update
* redis update
* writer_dup_bug/fix
* writer_graph_bug/fix
* writer_graph_bug/fix
---------
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Revert "feat(web): move prompt menu"
This reverts commit 9e6e8f50f8.
* fix(web): ui update
* fix(web): update text
* fix(web): ui update
* fix(model): change the "vl" model type of dashscope to "chat"
* fix(model): change the "vl" model type of dashscope to "chat"
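Several entries above add null/False checks around Redis session reads (`term_memory_save`, `memory_long_term_storage`, `aggregate_judgment`). A minimal sketch of such a guard follows. It assumes the stored session is a JSON-encoded list of messages and that a lookup can come back as `None` (missing key) or `False` (failed lookup). `load_session_messages` is a hypothetical helper, not a function in the codebase.

```python
import json
import logging
from typing import Any, List, Union

logger = logging.getLogger(__name__)


def load_session_messages(raw: Union[bytes, str, bool, None]) -> List[Any]:
    """Normalize a Redis session lookup into a list of messages.

    Hypothetical helper: None, False and empty payloads mean "no session"
    and return [] instead of letting json.loads raise on them.
    """
    if not raw:
        logger.debug("No existing session found in Redis")
        return []
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    messages = json.loads(raw)
    # Anything other than a JSON list is treated as an empty session.
    return messages if isinstance(messages, list) else []
```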
---------
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: Eternity <1533512157@qq.com>
Co-authored-by: Mark <zhuwenhui5566@163.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Timebomb2018 <18868801967@163.com>
Co-authored-by: 乐力齐 <162269739+lanceyq@users.noreply.github.com>
Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com>
Co-authored-by: lixinyue <2569494688@qq.com>
Co-authored-by: Eternity <61316157+myhMARS@users.noreply.github.com>
Co-authored-by: lanceyq <1982376970@qq.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
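The `feat(tasks)` entry above lists the options applied to the periodic jobs. The sketch below collects them in one place. It assumes they are passed as keyword arguments to the Celery task decorator; `ignore_result`, `max_retries`, `acks_late`, `time_limit` and `soft_time_limit` are standard Celery task options. The dictionary and function names are hypothetical.

The check enforces the invariant implied by the listed values: each soft limit is below its hard limit. This way the task receives `SoftTimeLimitExceeded` and can clean up before the worker process is killed.

```python
from typing import Any, Dict, Tuple

# Options for beat-scheduled tasks (values taken from the commit message).
PERIODIC_TASK_DEFAULTS: Dict[str, Any] = {
    "ignore_result": True,  # do not store results of periodic runs
    "max_retries": 0,       # skip a failed run instead of retrying it
    "acks_late": False,     # acknowledge the message as soon as it is received
}

# Per-task (hard, soft) time limits in seconds, as listed in the commit message.
TASK_TIME_LIMITS: Dict[str, Tuple[int, int]] = {
    "regenerate_memory_cache": (3600, 3300),
    "workspace_reflection_task": (300, 240),
    "run_forgetting_cycle_task": (7200, 7000),
}


def task_options(task_name: str) -> Dict[str, Any]:
    """Build the decorator keyword arguments for one periodic task."""
    hard, soft = TASK_TIME_LIMITS[task_name]
    # The soft limit must fire first so the task can clean up before it is killed.
    if not soft < hard:
        raise ValueError(f"{task_name}: soft_time_limit must be below time_limit")
    return {**PERIODIC_TASK_DEFAULTS, "time_limit": hard, "soft_time_limit": soft}


# Usage with Celery (not executed here):
# @celery_app.task(**task_options("workspace_reflection_task"))
# def workspace_reflection_task(): ...
```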
"""
Draft-run service.

Provides Agent draft runs, letting users test a configuration without publishing the app.
"""

import asyncio
import datetime
import json
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional

from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.repositories.model_repository import ModelApiKeyRepository
from app.repositories.tool_repository import ToolRepository
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.tool_service import ToolService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session

logger = get_business_logger()
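

class KnowledgeRetrievalInput(BaseModel):
    """Input parameters for the knowledge-base retrieval tool."""

    query: str = Field(description="The question or keywords to retrieve")


class WebSearchInput(BaseModel):
    """Input parameters for the web search tool."""

    query: str = Field(description="The question or keywords to search for")


class LongTermMemoryInput(BaseModel):
    """Input parameters for the long-term memory tool."""

    question: str = Field(
        description=(
            "The optimized, rewritten query. Rewrite the user's original question into a form "
            "better suited for retrieval, including keywords, context and specific details; "
            "check for and correct misspelled words."
        )
    )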


def create_long_term_memory_tool(
    memory_config: Dict[str, Any],
    end_user_id: str,
    storage_type: Optional[str] = None,
    user_rag_memory_id: Optional[str] = None,
):
    """Create the long-term memory tool.

    Args:
        memory_config: Memory configuration
        end_user_id: End-user ID
        storage_type: Storage type (optional)
        user_rag_memory_id: RAG memory ID of the user (optional)

    Returns:
        The long-term memory tool
    """
    # search_switch = memory_config.get("search_switch", "2")
    config_id = memory_config.get("memory_content") or memory_config.get("memory_config", None)
    logger.info(
        f"Creating long-term memory tool: end_user_id={end_user_id}, "
        f"config_id={config_id}, storage_type={storage_type}"
    )

    @tool(args_schema=LongTermMemoryInput)
    def long_term_memory(question: str) -> str:
        """
        Retrieve relevant information from the user's historical memory. This tool helps you
        understand the user's background, preferences and past conversations.

        Do NOT use this tool in the following scenarios:
        1. Emotional or social greetings (simple pleasantries such as "hello", "thanks", "goodbye")
        2. Purely task-oriented requests that need no historical context (such as "write code for me" or "translate this text")
        3. Processing external content (such as user-provided text, code or RAG data, which already contains the needed information)

        Use this tool in all other cases, especially when:
        - The user asks about personal information or past conversations
        - You need to know the user's preferences, habits or background
        - The user uses history-related words such as "before", "last time" or "remember"
        - A personalized reply or a suggestion based on historical context is needed
        - The user asks anything about themselves

        Rewrite and optimize `question` before calling, focusing on the following:
        - Include relevant keywords and keep the core meaning of the original question; use the
          context to make the question more specific and clear, turning vague expressions into
          explicit search terms
        - Expand the query with synonyms or related terms

        Args:
            question: The rewritten question

        Returns:
            The retrieved historical memory content
        """
        logger.info(f"Long-term memory tool invoked: question={question}, user={end_user_id}")
        logger.debug("Calling long-term memory API", extra={"question": question, "end_user_id": end_user_id})
        try:
            from app.db import get_db

            db = next(get_db())
            try:
                # Direct read: the returned answer comes from this call.
                memory_content = asyncio.run(
                    MemoryAgentService().read_memory(
                        end_user_id=end_user_id,
                        message=question,
                        history=[],
                        search_switch="2",
                        config_id=config_id,
                        db=db,
                        storage_type=storage_type,
                        user_rag_memory_id=user_rag_memory_id,
                    )
                )
                # The Celery task result is only used for status logging below.
                task = celery_app.send_task(
                    "app.core.memory.agent.read_message",
                    args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id],
                )
                result = task_service.get_task_memory_read_result(task.id)
                status = result.get("status")
                logger.info(f"Memory read task status: {status}")
                # Avoid returning the literal string "None" when nothing was found.
                memory_content = memory_content["answer"] if memory_content else ""
            finally:
                db.close()

            logger.info(
                "Long-term memory retrieval succeeded",
                extra={
                    "end_user_id": end_user_id,
                    "content_length": len(str(memory_content)),
                },
            )
            return f"Retrieved the following historical memory:\n\n{memory_content}"
        except Exception as e:
            logger.error("Long-term memory retrieval failed", extra={"error": str(e), "error_type": type(e).__name__})
            return f"Memory retrieval failed: {str(e)}"

    return long_term_memory
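`long_term_memory` is a synchronous tool that calls `asyncio.run(...)`. That works when the tool runs in a thread without an event loop, for example a worker thread. However, `asyncio.run` raises `RuntimeError` if the calling thread already has a running loop.

A minimal sketch of a defensive bridge follows, assuming it would wrap the `read_memory` coroutine. `run_coroutine_sync` is a hypothetical name, not an existing project function. It runs the coroutine directly when no loop is running. Otherwise it hands the coroutine to a short-lived thread that owns a fresh loop.

```python
import asyncio
import concurrent.futures
from typing import Any, Coroutine, TypeVar

T = TypeVar("T")


def run_coroutine_sync(coro: Coroutine[Any, Any, T]) -> T:
    """Run a coroutine to completion from synchronous code (hypothetical helper)."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so asyncio.run is safe.
        return asyncio.run(coro)
    # A loop is already running here: run the coroutine on its own loop in a
    # separate thread. This blocks the calling thread (and its loop) until done.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```

Blocking the caller's loop is acceptable only as a fallback. Declaring the tool itself as `async def` under `@tool` would likely avoid the bridge entirely.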


def create_web_search_tool(web_search_config: Dict[str, Any]):
    """Create the web search tool.

    Args:
        web_search_config: Web search configuration (not used by the current implementation)

    Returns:
        The web search tool
    """
    logger.info("Creating web search tool")

    @tool(args_schema=WebSearchInput)
    def web_search_tool(query: str) -> str:
        """Search the internet for up-to-date information. Use this tool when the user's question needs real-time information, recent news or online resources.

        Args:
            query: The question or keywords to search for

        Returns:
            Relevant information found on the web
        """
        try:
            logger.info(f"Running web search: {query}")

            # Call the search service
            search_result = Search(query)
            logger.info(
                "Web search succeeded",
                extra={
                    "query": query,
                    "result_length": len(search_result),
                },
            )

            return f"Found the following web information:\n\n{search_result}"

        except Exception as e:
            logger.error("Web search failed", extra={"error": str(e), "error_type": type(e).__name__})
            return f"Search failed: {str(e)}"

    return web_search_tool


def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
    """Create the knowledge-base retrieval tool.

    Args:
        kb_config: Knowledge retrieval configuration passed to ``knowledge_retrieval``
        kb_ids: Knowledge-base identifiers (used for logging only)
        user_id: User ID

    Returns:
        The knowledge-base retrieval tool
    """
    logger.info(f"Creating knowledge-base retrieval tool, user: {user_id}")

    @tool(args_schema=KnowledgeRetrievalInput)
    def knowledge_retrieval_tool(query: str) -> str:
        """Retrieve relevant information from the knowledge base. Use this tool when the user's question requires consulting the knowledge base, documents or historical records.

        Args:
            query: The question or keywords to retrieve

        Returns:
            The retrieved knowledge content
        """
        try:
            retrieve_chunks_result = knowledge_retrieval(query, kb_config)
            if retrieve_chunks_result:
                retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
                context = "\n\n".join(retrieval_knowledge)
                logger.info(
                    "Knowledge-base retrieval succeeded",
                    extra={
                        "kb_ids": kb_ids,
                        "result_count": len(retrieval_knowledge),
                        "total_length": len(context),
                    },
                )

                return f"Retrieved the following relevant information:\n\n{context}"
            else:
                logger.warning("Knowledge-base retrieval returned no results")
                return "No relevant information found"
        except Exception as e:
            logger.error("Knowledge-base retrieval failed", extra={"error": str(e), "error_type": type(e).__name__})
            return f"Retrieval failed: {str(e)}"

    return knowledge_retrieval_tool
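The `DraftRunService` methods decide whether to attach this tool by inspecting `knowledge_retrieval["knowledge_bases"]`. The tool itself joins each chunk's `page_content` with blank lines.

The sketch below shows both steps. It assumes each knowledge-base entry is a dict with an optional `kb_id`, and each retrieved chunk exposes `page_content` (as LangChain `Document` objects do). `Chunk`, `collect_kb_ids` and `build_context` are illustrative names only.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence


@dataclass
class Chunk:
    """Stand-in for a retrieved document chunk."""
    page_content: str


def collect_kb_ids(kb_config: Dict[str, Any]) -> List[str]:
    """Return the configured knowledge-base ids, skipping entries without one."""
    return [kb["kb_id"] for kb in kb_config.get("knowledge_bases", []) if kb.get("kb_id")]


def build_context(chunks: Sequence[Chunk]) -> str:
    """Join chunk texts with a blank line between them."""
    return "\n\n".join(chunk.page_content for chunk in chunks)
```

Returning a list rather than `bool(...)` keeps the value meaningful when it is logged as `kb_ids`.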


class DraftRunService:
    """Draft-run service."""

    def __init__(self, db: Session):
        """Initialize the draft-run service.

        Args:
            db: Database session
        """
        self.db = db

    async def run(
        self,
        *,
        agent_config: AgentConfig,
        model_config: ModelConfig,
        message: str,
        workspace_id: uuid.UUID,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
        web_search: bool = True,
        memory: bool = True,
        sub_agent: bool = False,
    ) -> Dict[str, Any]:
        """Run a draft (using a LangChain Agent).

        Args:
            agent_config: Agent configuration
            model_config: Model configuration
            message: User message
            workspace_id: Workspace ID (required, used for conversation isolation)
            conversation_id: Conversation ID (for multi-turn conversations)
            user_id: User ID
            variables: Values for custom variable parameters
            storage_type: Memory storage type (optional)
            user_rag_memory_id: RAG memory ID of the user (optional)
            web_search: Whether search tools may be attached
            memory: Whether the long-term memory tool may be attached
            sub_agent: Whether this agent runs as a sub-agent; sub-agent runs do not save conversation messages

        Returns:
            Dict: The AI reply and its metadata
        """
        memory_flag = False
        logger.debug("Starting draft run", extra={"storage_type": storage_type, "user_id": user_id})
        if variables is None:
            variables = {}
        from app.core.agent.langchain_agent import LangChainAgent

        start_time = time.time()

        try:
            # 1. Get the API key configuration
            api_key_config = await self._get_api_key(model_config.id)
            logger.debug(
                "API key configuration loaded",
                extra={
                    "model_name": api_key_config["model_name"],
                    "has_api_key": bool(api_key_config["api_key"]),
                    "has_api_base": bool(api_key_config.get("api_base")),
                },
            )

            # 2. Merge model parameters
            effective_params = ModelParameterMerger.get_effective_parameters(
                model_config=model_config,
                agent_config=agent_config,
            )

            # 3. Render the system prompt (supports variable substitution)
            system_prompt = render_prompt_message(
                agent_config.system_prompt,
                PromptMessageRole.USER,
                variables,
            )
            system_prompt = system_prompt.get_text_content() or "You are a professional AI assistant"
            logger.debug("System prompt rendered", extra={"system_prompt": system_prompt})

            # 4. Prepare the tool list
            tools = []
            tool_service = ToolService(self.db)

            # Collect the tools enabled in the configuration
            if hasattr(agent_config, "tools") and agent_config.tools and isinstance(agent_config.tools, list):
                for tool_config in agent_config.tools:
                    logger.debug("Processing tool configuration", extra={"tool_config": tool_config})
                    if tool_config.get("enabled", False):
                        # Look up the tool instance by its tool ID
                        tool_instance = tool_service._get_tool_instance(
                            tool_config.get("tool_id", ""),
                            ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)),
                        )
                        if tool_instance:
                            if tool_instance.name == "baidu_search_tool" and not web_search:
                                continue
                            # Convert to a LangChain tool
                            langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
                            tools.append(langchain_tool)
            elif hasattr(agent_config, "tools") and agent_config.tools and isinstance(agent_config.tools, dict):
                web_search_choice = agent_config.tools.get("web_search", {})
                web_search_enable = web_search_choice.get("enabled", False)
                if web_search and web_search_enable:
                    search_tool = create_web_search_tool({})
                    tools.append(search_tool)

                    logger.debug(
                        "Web search tool added",
                        extra={"tool_count": len(tools)},
                    )

            # Add the knowledge-base retrieval tool
            if agent_config.knowledge_retrieval:
                kb_config = agent_config.knowledge_retrieval
                knowledge_bases = kb_config.get("knowledge_bases", [])
                kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")]
                if kb_ids:
                    # Create the knowledge-base retrieval tool
                    kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
                    tools.append(kb_tool)

                    logger.debug(
                        "Knowledge-base retrieval tool added",
                        extra={"kb_ids": kb_ids, "tool_count": len(tools)},
                    )

            # Add the long-term memory tool
            if memory and agent_config.memory and agent_config.memory.get("enabled"):
                memory_flag = True
                memory_config = agent_config.memory
                if user_id:
                    # Create the long-term memory tool
                    memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, user_rag_memory_id)
                    tools.append(memory_tool)

                    logger.debug(
                        "Long-term memory tool added",
                        extra={"user_id": user_id, "tool_count": len(tools)},
                    )

            # 5. Create the LangChain Agent
            agent = LangChainAgent(
                model_name=api_key_config["model_name"],
                api_key=api_key_config["api_key"],
                provider=api_key_config.get("provider", "openai"),
                api_base=api_key_config.get("api_base"),
                temperature=effective_params.get("temperature", 0.7),
                max_tokens=effective_params.get("max_tokens", 2000),
                system_prompt=system_prompt,
                tools=tools,
            )

            # 6. Resolve the conversation ID (create or validate)
            conversation_id = await self._ensure_conversation(
                conversation_id=conversation_id,
                app_id=agent_config.app_id,
                workspace_id=workspace_id,
                user_id=user_id,
            )

            # 7. Load the conversation history
            history = []
            if agent_config.memory and agent_config.memory.get("enabled"):
                history = await self._load_conversation_history(
                    conversation_id=conversation_id,
                    max_history=agent_config.memory.get("max_history", 10),
                )

            # 8. Knowledge-base context (not pre-fetched here)
            context = None

            logger.debug(
                "Preparing to call LangChain Agent",
                extra={
                    "model": api_key_config["model_name"],
                    "has_history": bool(history),
                    "has_context": bool(context),
                },
            )

            # agent_config.memory may be None, so fall back to an empty dict
            memory_config_ = agent_config.memory or {}
            config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config", None)

            # 9. Call the Agent
            result = await agent.chat(
                message=message,
                history=history,
                context=context,
                end_user_id=user_id,
                config_id=config_id,
                storage_type=storage_type,
                user_rag_memory_id=user_rag_memory_id,
                memory_flag=memory_flag,
            )

            elapsed_time = time.time() - start_time
            usage = result.get("usage", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})

            # 10. Save the conversation messages (skipped for sub-agent runs)
            if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
                await self._save_conversation_message(
                    conversation_id=conversation_id,
                    user_message=message,
                    assistant_message=result["content"],
                    app_id=agent_config.app_id,
                    user_id=user_id,
                    meta_data={"usage": usage},
                )

            response = {
                "message": result["content"],
                "conversation_id": conversation_id,
                "usage": usage,
                "elapsed_time": elapsed_time,
            }

            logger.info(
                "Draft run completed",
                extra={
                    "model": model_config.name,
                    "elapsed_time": elapsed_time,
                    "message_length": len(result["content"]),
                    "total_tokens": usage.get("total_tokens", 0),
                },
            )

            return response

        except Exception as e:
            logger.error("LangChain Agent call failed", extra={"error": str(e), "error_type": type(e).__name__})
            raise BusinessException(f"Agent call failed: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
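`run` accepts `agent_config.tools` in two shapes:
- a list of per-tool entries (`enabled`, `tool_id`, `operation`), or
- a dict with a `web_search` switch.

In both cases the `web_search` argument can veto search tools. A pure-function sketch of that dispatch follows. It assumes tool names can be looked up by id through a caller-supplied mapping. `plan_tools` and `lookup` are hypothetical stand-ins for `ToolService._get_tool_instance`.

```python
from typing import Any, Dict, List, Optional, Union

ToolsConfig = Union[List[Dict[str, Any]], Dict[str, Any], None]


def plan_tools(tools_config: ToolsConfig, lookup: Dict[str, str], web_search: bool = True) -> List[str]:
    """Return the names of the tools that would be attached to the agent.

    lookup maps tool_id -> tool name (hypothetical stand-in for the tool service).
    """
    names: List[str] = []
    if isinstance(tools_config, list):
        for entry in tools_config:
            if not entry.get("enabled", False):
                continue
            name: Optional[str] = lookup.get(entry.get("tool_id", ""))
            if name is None:
                continue  # unknown tool id: skipped, like a missing tool instance
            if name == "baidu_search_tool" and not web_search:
                continue  # web search disabled for this run
            names.append(name)
    elif isinstance(tools_config, dict):
        if web_search and tools_config.get("web_search", {}).get("enabled", False):
            names.append("web_search_tool")
    return names
```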

    async def run_stream(
        self,
        *,
        agent_config: AgentConfig,
        model_config: ModelConfig,
        message: str,
        workspace_id: uuid.UUID,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
        web_search: bool = True,
        memory: bool = True,
        sub_agent: bool = False,
    ) -> AsyncGenerator[str, None]:
        """Run a draft with streaming output (using a LangChain Agent).

        Args:
            agent_config: Agent configuration
            model_config: Model configuration
            message: User message
            workspace_id: Workspace ID (required, used for conversation isolation)
            conversation_id: Conversation ID (for multi-turn conversations)
            user_id: User ID
            variables: Values for custom variable parameters
            storage_type: Memory storage type (optional)
            user_rag_memory_id: RAG memory ID of the user (optional)
            web_search: Whether search tools may be attached
            memory: Whether the long-term memory tool may be attached
            sub_agent: Whether this agent runs as a sub-agent; sub-agent runs emit a
                ``sub_usage`` event and do not save conversation messages

        Yields:
            str: SSE-formatted event data
        """
        memory_flag = False
        if variables is None:
            variables = {}

        from app.core.agent.langchain_agent import LangChainAgent

        start_time = time.time()

        try:
            # 1. Get the API key configuration
            api_key_config = await self._get_api_key(model_config.id)

            # 2. Merge model parameters
            effective_params = ModelParameterMerger.get_effective_parameters(
                model_config=model_config,
                agent_config=agent_config,
            )

            # 3. Render the system prompt (supports variable substitution)
            system_prompt = render_prompt_message(
                agent_config.system_prompt,
                PromptMessageRole.USER,
                variables,
            )
            system_prompt = system_prompt.get_text_content() or "You are a professional AI assistant"

            # 4. Prepare the tool list
            tools = []
            tool_service = ToolService(self.db)

            # Collect the tools enabled in the configuration
            if hasattr(agent_config, "tools") and agent_config.tools and isinstance(agent_config.tools, list):
                for tool_config in agent_config.tools:
                    if tool_config.get("enabled", False):
                        # Look up the tool instance by its tool ID
                        tool_instance = tool_service._get_tool_instance(
                            tool_config.get("tool_id", ""),
                            ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)),
                        )
                        if tool_instance:
                            if tool_instance.name == "baidu_search_tool" and not web_search:
                                continue
                            # Convert to a LangChain tool
                            langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
                            tools.append(langchain_tool)
            elif hasattr(agent_config, "tools") and agent_config.tools and isinstance(agent_config.tools, dict):
                web_search_choice = agent_config.tools.get("web_search", {})
                web_search_enable = web_search_choice.get("enabled", False)
                if web_search and web_search_enable:
                    search_tool = create_web_search_tool({})
                    tools.append(search_tool)

                    logger.debug(
                        "Web search tool added",
                        extra={"tool_count": len(tools)},
                    )

            # Add the knowledge-base retrieval tool
            if agent_config.knowledge_retrieval:
                kb_config = agent_config.knowledge_retrieval
                knowledge_bases = kb_config.get("knowledge_bases", [])
                kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")]
                if kb_ids:
                    # Create the knowledge-base retrieval tool
                    kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
                    tools.append(kb_tool)

                    logger.debug(
                        "Knowledge-base retrieval tool added",
                        extra={"kb_ids": kb_ids, "tool_count": len(tools)},
                    )

            # Add the long-term memory tool
            if memory and agent_config.memory and agent_config.memory.get("enabled"):
                memory_flag = True
                memory_config = agent_config.memory
                if user_id:
                    # Create the long-term memory tool
                    memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, user_rag_memory_id)
                    tools.append(memory_tool)

                    logger.debug(
                        "Long-term memory tool added",
                        extra={"user_id": user_id, "tool_count": len(tools)},
                    )

            # 5. Create the LangChain Agent (streaming mode)
            agent = LangChainAgent(
                model_name=api_key_config["model_name"],
                api_key=api_key_config["api_key"],
                provider=api_key_config.get("provider", "openai"),
                api_base=api_key_config.get("api_base"),
                temperature=effective_params.get("temperature", 0.7),
                max_tokens=effective_params.get("max_tokens", 2000),
                system_prompt=system_prompt,
                tools=tools,
                streaming=True,
            )

            # 6. Resolve the conversation ID (create or validate)
            conversation_id = await self._ensure_conversation(
                conversation_id=conversation_id,
                app_id=agent_config.app_id,
                workspace_id=workspace_id,
                user_id=user_id,
            )

            # 7. Load the conversation history
            history = []
            if agent_config.memory and agent_config.memory.get("enabled"):
                history = await self._load_conversation_history(
                    conversation_id=conversation_id,
                    max_history=agent_config.memory.get("max_history", 10),
                )

            # 8. Knowledge-base context (not pre-fetched here)
            context = None

            # 9. Send the start event
            yield self._format_sse_event("start", {
                "conversation_id": conversation_id,
                "timestamp": time.time(),
            })

            # agent_config.memory may be None, so fall back to an empty dict
            memory_config_ = agent_config.memory or {}
            config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config", None)

            # 10. Stream the Agent output
            full_content = ""
            total_tokens = 0
            async for chunk in agent.chat_stream(
                message=message,
                history=history,
                context=context,
                end_user_id=user_id,
                config_id=config_id,
                storage_type=storage_type,
                user_rag_memory_id=user_rag_memory_id,
                memory_flag=memory_flag,
            ):
                if isinstance(chunk, int):
                    # An integer chunk carries the total token count
                    total_tokens = chunk
                else:
                    full_content += chunk
                    # Send a message-chunk event
                    yield self._format_sse_event("message", {"content": chunk})

            elapsed_time = time.time() - start_time

            if sub_agent:
                yield self._format_sse_event("sub_usage", {"total_tokens": total_tokens})

            # 11. Save the conversation messages (skipped for sub-agent runs)
            if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
                await self._save_conversation_message(
                    conversation_id=conversation_id,
                    user_message=message,
                    assistant_message=full_content,
                    app_id=agent_config.app_id,
                    user_id=user_id,
                    meta_data={
                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
                    },
                )

            # 12. Send the end event
            yield self._format_sse_event("end", {
                "conversation_id": conversation_id,
                "elapsed_time": elapsed_time,
                "message_length": len(full_content),
            })

            logger.info(
                "Streaming draft run completed",
                extra={
                    "model": model_config.name,
                    "elapsed_time": elapsed_time,
                    "message_length": len(full_content),
                },
            )

        except Exception as e:
            logger.error("Streaming Agent call failed", extra={"error": str(e)}, exc_info=True)
            # Send an error event
            yield self._format_sse_event("error", {
                "error": str(e),
                "timestamp": time.time(),
            })
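`run_stream` consumes a stream that mixes text chunks (`str`) with a token count (`int`). It forwards each text chunk as a `message` event, emits `sub_usage` only for sub-agents, and frames the exchange with `start` and `end` events.

A self-contained sketch of that loop follows. It replaces `LangChainAgent.chat_stream` with a fake async generator and drops timestamps and elapsed time. `fake_agent_stream`, `relay` and `sse` are illustrative names; `sse` mirrors `_format_sse_event`.

```python
import asyncio
import json
from typing import Any, AsyncIterator, Dict, Union


def sse(event: str, data: Dict[str, Any]) -> str:
    # Same wire format as _format_sse_event.
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


async def fake_agent_stream() -> AsyncIterator[Union[str, int]]:
    for piece in ("Hel", "lo"):
        yield piece
    yield 42  # an int chunk carries the total token count


async def relay(
    stream: AsyncIterator[Union[str, int]],
    conversation_id: str = "demo",
    sub_agent: bool = False,
) -> AsyncIterator[str]:
    yield sse("start", {"conversation_id": conversation_id})
    full_content, total_tokens = "", 0
    try:
        async for chunk in stream:
            if isinstance(chunk, int):
                total_tokens = chunk
            else:
                full_content += chunk
                yield sse("message", {"content": chunk})
        if sub_agent:
            yield sse("sub_usage", {"total_tokens": total_tokens})
        yield sse("end", {"conversation_id": conversation_id, "message_length": len(full_content)})
    except Exception as exc:
        # Once the response has started, failures can only be reported in-band.
        yield sse("error", {"error": str(exc)})
```

Errors are reported as an `error` event rather than raised. Once a streaming response has started, its HTTP status code can no longer change.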

    def _format_sse_event(self, event: str, data: Dict[str, Any]) -> str:
        """Format an SSE event.

        Args:
            event: Event type
            data: Event data

        Returns:
            str: SSE-formatted string
        """
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
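`_format_sse_event` emits one `event:` line, one `data:` line holding a JSON payload, and a blank-line terminator. Because `ensure_ascii=False` is set, Chinese text stays readable on the wire. Because `json.dumps` escapes newlines inside strings, the payload always fits on a single `data:` line.

A minimal client-side sketch follows that splits such a stream back into `(event, data)` pairs. `parse_sse` is a hypothetical helper, not a full implementation of the SSE specification: it assumes exactly one `data:` line per event, as this method produces, and ignores `id:`, `retry:`, comments and multi-line data.

```python
import json
from typing import Any, Dict, Iterator, Optional, Tuple


def format_sse_event(event: str, data: Dict[str, Any]) -> str:
    # Same format as DraftRunService._format_sse_event.
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


def parse_sse(stream: str) -> Iterator[Tuple[str, Optional[Dict[str, Any]]]]:
    """Yield (event, data) pairs from blank-line separated SSE events."""
    for block in stream.split("\n\n"):
        if not block.strip():
            continue  # trailing separator
        event: str = "message"  # SSE's default event type
        payload: Optional[Dict[str, Any]] = None
        for line in block.split("\n"):
            if line.startswith("event:"):
                event = line[len("event:"):].strip()
            elif line.startswith("data:"):
                payload = json.loads(line[len("data:"):].strip())
        yield event, payload
```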

    async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]:
        """Get the API key for a model.

        Args:
            model_config_id: Model configuration ID

        Returns:
            Dict: model_name, provider, api_key and api_base

        Raises:
            BusinessException: When no API key is available
        """
        api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
        # Use the first key returned by the repository
        api_key = api_keys[0] if api_keys else None

        if not api_key:
            raise BusinessException("No API key available", BizCode.AGENT_CONFIG_MISSING)

        return {
            "model_name": api_key.model_name,
            "provider": api_key.provider,
            "api_key": api_key.api_key,
            "api_base": api_key.api_base,
        }
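`_get_api_key` takes the first key returned by `ModelApiKeyRepository.get_by_model_config`. Meanwhile, the commit log mentions multi-API-key load balancing with active-status filtering.

One possible selection strategy is sketched below, assuming each key exposes `is_active` and `priority` fields. It keeps the active keys, takes the highest-priority tier, and rotates round-robin within that tier. `KeyRecord` and `KeySelector` are hypothetical, and this is not the project's actual strategy.

```python
import itertools
import threading
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Tuple


@dataclass(frozen=True)
class KeyRecord:
    """Minimal stand-in for a model API key row."""
    api_key: str
    priority: int = 0
    is_active: bool = True


class KeySelector:
    """Round-robin over the highest-priority active keys (illustrative only)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._cycles: Dict[Tuple[str, ...], Iterator[KeyRecord]] = {}

    def pick(self, keys: List[KeyRecord]) -> Optional[KeyRecord]:
        active = [k for k in keys if k.is_active]
        if not active:
            return None  # the service would raise BusinessException here
        top = max(k.priority for k in active)
        tier = [k for k in active if k.priority == top]
        signature = tuple(k.api_key for k in tier)
        with self._lock:
            # Reuse one rotation per tier so successive calls alternate between keys.
            cycle = self._cycles.setdefault(signature, itertools.cycle(tier))
            return next(cycle)
```

In a multi-process deployment, each worker keeps its own rotation. A global round-robin would need a shared counter, for example one stored in Redis.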

    async def _ensure_conversation(
        self,
        conversation_id: Optional[str],
        app_id: uuid.UUID,
        workspace_id: uuid.UUID,
        user_id: Optional[str]
    ) -> str:
        """Ensure the conversation exists (create a new one or validate the given one).

        Args:
            conversation_id: Conversation ID (optional; a new draft conversation is created if omitted)
            app_id: Application ID
            workspace_id: Workspace ID (required)
            user_id: User ID

        Returns:
            str: Conversation ID

        Raises:
            BusinessException: If the given conversation does not exist or belongs to another workspace
        """
        from app.models import Conversation as ConversationModel
        from app.services.conversation_service import ConversationService

        conversation_service = ConversationService(self.db)

        # No conversation ID supplied: create a new draft conversation
        if not conversation_id:
            logger.info(
                "Creating a new draft conversation",
                extra={"workspace_id": str(workspace_id)}
            )

            # Snapshot the current configuration
            config_snapshot = await self._get_config_snapshot(app_id)

            # Create the conversation
            new_conv_uuid = uuid.uuid4()
            new_conv_id = str(new_conv_uuid)
            new_conversation = ConversationModel(
                id=new_conv_uuid,
                app_id=app_id,
                workspace_id=workspace_id,
                user_id=user_id,
                is_draft=True,
                title="草稿会话",  # "Draft conversation"
                config_snapshot=config_snapshot
            )
            self.db.add(new_conversation)
            self.db.commit()
            self.db.refresh(new_conversation)

            logger.info(
                "Draft conversation created",
                extra={
                    "conversation_id": new_conv_id,
                    "workspace_id": str(workspace_id)
                }
            )

            return new_conv_id

        # Conversation ID supplied: verify that it exists and belongs to this workspace
        try:
            conv_uuid = uuid.UUID(conversation_id)
            conversation = conversation_service.get_conversation(conv_uuid)

            # The conversation must belong to the current workspace
            if conversation.workspace_id != workspace_id:
                logger.warning(
                    "Conversation does not belong to the current workspace",
                    extra={
                        "conversation_id": conversation_id,
                        "conversation_workspace_id": str(conversation.workspace_id),
                        "current_workspace_id": str(workspace_id)
                    }
                )
                raise BusinessException(
                    "会话不属于当前工作空间",  # "Conversation does not belong to the current workspace"
                    BizCode.PERMISSION_DENIED
                )

            logger.debug(
                "Using existing conversation",
                extra={
                    "conversation_id": conversation_id,
                    "workspace_id": str(workspace_id)
                }
            )
            return conversation_id

        except BusinessException:
            raise
        except Exception as e:
            logger.error(
                "Conversation not found or invalid",
                extra={"conversation_id": conversation_id, "error": str(e)}
            )
            raise BusinessException(
                f"会话不存在: {conversation_id}",  # "Conversation not found: ..."
                BizCode.NOT_FOUND,
                cause=e
            )

    async def _load_conversation_history(
        self,
        conversation_id: str,
        max_history: int = 10
    ) -> List[Dict[str, str]]:
        """Load the conversation's message history.

        Args:
            conversation_id: Conversation ID
            max_history: Maximum number of history messages to load

        Returns:
            List[Dict]: History messages (an empty list if loading fails)
        """
        try:
            from app.services.conversation_service import ConversationService

            conversation_service = ConversationService(self.db)
            history = conversation_service.get_conversation_history(
                conversation_id=uuid.UUID(conversation_id),
                max_history=max_history
            )

            logger.debug(
                "Loaded conversation history",
                extra={
                    "conversation_id": conversation_id,
                    "max_history": max_history,
                    "loaded_count": len(history)
                }
            )

            return history

        except Exception as e:
            # A new conversation legitimately has no history
            logger.debug(
                "Failed to load conversation history (possibly a new conversation)",
                extra={"error": str(e)}
            )
            return []

    async def _save_conversation_message(
        self,
        conversation_id: str,
        user_message: str,
        assistant_message: str,
        meta_data: dict,
        app_id: Optional[uuid.UUID] = None,
        user_id: Optional[str] = None
    ) -> None:
        """Save a user/assistant message pair (the conversation was already created by _ensure_conversation).

        Failures are logged as warnings and not re-raised.

        Args:
            conversation_id: Conversation ID
            user_message: User message
            assistant_message: AI reply message
            meta_data: Metadata stored with the assistant message (token usage)
            app_id: Application ID (unused; kept for compatibility)
            user_id: User ID (unused; kept for compatibility)
        """
        try:
            from app.services.conversation_service import ConversationService

            conversation_service = ConversationService(self.db)
            conv_uuid = uuid.UUID(conversation_id)

            # Save the messages (the conversation already exists)
            # Save the user message
            conversation_service.add_message(
                conversation_id=conv_uuid,
                role="user",
                content=user_message
            )
            # Save the assistant message
            conversation_service.add_message(
                conversation_id=conv_uuid,
                role="assistant",
                content=assistant_message,
                meta_data=meta_data
            )

            logger.debug(
                "Saved conversation messages",
                extra={
                    "conversation_id": conversation_id,
                    "user_message_length": len(user_message),
                    "assistant_message_length": len(assistant_message)
                }
            )

        except Exception as e:
            logger.warning("Failed to save conversation messages", extra={"error": str(e)})

    async def _get_config_snapshot(self, app_id: uuid.UUID) -> Dict[str, Any]:
        """Take a snapshot of the app's current configuration.

        Args:
            app_id: Application ID

        Returns:
            Dict: Configuration snapshot ({} if the app has no AgentConfig or the lookup fails)
        """
        try:
            import enum

            from app.models import AgentConfig, ModelConfig

            # Load the agent configuration
            stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
            agent_cfg = self.db.scalars(stmt).first()

            if not agent_cfg:
                return {}

            # Load the model configuration
            model_config = None
            if agent_cfg.default_model_config_id:
                model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)

            # Build the snapshot (every value must be JSON-serializable)
            def safe_serialize(value: Any) -> Any:
                """Recursively convert a value into a JSON-serializable form."""
                if value is None or isinstance(value, (str, int, float, bool)):
                    return value
                if isinstance(value, dict):
                    return {str(k): safe_serialize(v) for k, v in value.items()}
                if isinstance(value, (list, tuple, set)):
                    return [safe_serialize(v) for v in value]
                if isinstance(value, enum.Enum):
                    return safe_serialize(value.value)
                # Pydantic models: model_dump() in v2, dict() in v1
                if hasattr(value, "model_dump"):
                    return safe_serialize(value.model_dump())
                if hasattr(value, "dict"):
                    return safe_serialize(value.dict())
                # Other objects: public attributes only (skips e.g. SQLAlchemy's _sa_instance_state)
                if hasattr(value, "__dict__"):
                    return {
                        k: safe_serialize(v)
                        for k, v in vars(value).items()
                        if not k.startswith("_")
                    }
                return str(value)

            snapshot = {
                "agent_config": {
                    "system_prompt": agent_cfg.system_prompt,
                    "model_parameters": safe_serialize(agent_cfg.model_parameters),
                    "knowledge_retrieval": safe_serialize(agent_cfg.knowledge_retrieval),
                    "memory": safe_serialize(agent_cfg.memory),
                    "variables": safe_serialize(agent_cfg.variables),
                    "tools": safe_serialize(agent_cfg.tools)
                },
                "model_config": {
                    "model_name": model_config.name,
                    "provider": safe_serialize(model_config.provider),
                    "type": safe_serialize(model_config.type)
                } if model_config else None,
                "snapshot_time": datetime.datetime.now().isoformat()
            }

            return snapshot

        except Exception as e:
            # Multi-agent apps legitimately have no direct AgentConfig
            logger.debug(
                "Failed to get config snapshot (possibly a multi-agent app)",
                extra={"error": str(e)}
            )
            return {}
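
    # Illustrative shape of the snapshot stored on a new draft conversation
    # (values are made up; nested contents depend on the stored agent config):
    #
    #   {
    #       "agent_config": {
    #           "system_prompt": "You are a helpful assistant.",
    #           "model_parameters": {"temperature": 0.7},
    #           "knowledge_retrieval": None, "memory": None,
    #           "variables": [], "tools": []
    #       },
    #       "model_config": {"model_name": "...", "provider": "...", "type": "..."},
    #       "snapshot_time": "2024-01-01T12:00:00.000000"
    #   }
    #
    # "model_config" is None when the agent has no default model, and the whole
    # snapshot is {} when the app has no AgentConfig (e.g. a multi-agent app).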

    def _replace_variables(
        self,
        text: str,
        values: Dict[str, Any],
        definitions: List[Dict[str, Any]]
    ) -> str:
        """Replace the variables in a text.

        Args:
            text: Original text
            values: Variable values
            definitions: Variable definitions; values whose name is not defined are skipped

        Returns:
            str: Text with the variables replaced
        """
        result = text

        # Map variable name -> definition
        var_defs = {var["name"]: var for var in definitions}

        for var_name, var_value in values.items():
            # Skip variables that are not defined
            if var_name not in var_defs:
                logger.warning(f"Undefined variable: {var_name}")
                continue

            # Replace the variable (several placeholder syntaxes are supported).
            # "{var_name}" is a substring of both "{{var_name}}" and "${var_name}", so it
            # must be tried last; otherwise "${name}" would be turned into "$value".
            placeholders = [
                f"{{{{{var_name}}}}}",  # {{var_name}}
                f"${{{var_name}}}",     # ${var_name}
                f"{{{var_name}}}",      # {var_name}
            ]

            for placeholder in placeholders:
                if placeholder in result:
                    result = result.replace(placeholder, str(var_value))

        return result
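
    # Illustrative example of _replace_variables (all values hypothetical):
    #
    #   text        = "Hi {{name}}, your order ships {day}."
    #   values      = {"name": "Ann", "day": "Monday", "extra": 1}
    #   definitions = [{"name": "name"}, {"name": "day"}]
    #
    #   self._replace_variables(text, values, definitions)
    #   -> "Hi Ann, your order ships Monday."
    #
    # "extra" has no definition, so it is skipped and logged as a warning.
    # Values are inserted with str(), and "${var_name}" placeholders are accepted too.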

    # ==================== Multi-model comparison draft run ====================

    async def run_compare(
        self,
        *,
        agent_config: AgentConfig,
        models: List[Dict[str, Any]],
        message: str,
        workspace_id: uuid.UUID,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        parallel: bool = True,
        timeout: int = 60,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
        web_search: bool = True,
        memory: bool = True,
    ) -> Dict[str, Any]:
        """Run a draft comparison across multiple models.

        Args:
            agent_config: Agent configuration
            models: Model entries; each contains model_config, parameters, label,
                model_config_id and, optionally, its own conversation_id
            message: User message
            workspace_id: Workspace ID
            conversation_id: Conversation ID (used by models without their own conversation_id)
            user_id: User ID
            variables: Variable values
            parallel: Whether to run the models in parallel
            timeout: Per-model timeout in seconds
            storage_type: Storage type
            user_rag_memory_id: RAG memory ID
            web_search: Whether web search is enabled
            memory: Whether memory is enabled

        Returns:
            Dict: Comparison results
        """
        logger.info(
            "Multi-model comparison draft run",
            extra={
                "model_count": len(models),
                "parallel": parallel
            }
        )

        async def run_single_model(model_info):
            """Run a single model."""
            # Use the model's own conversation_id, falling back to the shared one
            model_conversation_id = model_info.get("conversation_id") or conversation_id
            try:
                start_time = time.time()

                # Temporarily override the parameters (no deepcopy, to avoid SQLAlchemy session issues).
                # NOTE: in parallel mode every task mutates this same agent_config object.
                original_params = agent_config.model_parameters
                agent_config.model_parameters = model_info["parameters"]

                try:
                    result = await asyncio.wait_for(
                        self.run(
                            agent_config=agent_config,
                            model_config=model_info["model_config"],
                            message=message,
                            workspace_id=workspace_id,
                            conversation_id=model_conversation_id,
                            user_id=user_id,
                            variables=variables,
                            storage_type=storage_type,
                            user_rag_memory_id=user_rag_memory_id,
                            web_search=web_search,
                            memory=memory
                        ),
                        timeout=timeout
                    )
                finally:
                    # Restore the original parameters
                    agent_config.model_parameters = original_params

                elapsed = time.time() - start_time
                usage = result.get("usage") or {}

                return {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "label": model_info["label"],
                    "conversation_id": result.get("conversation_id", model_conversation_id),
                    "parameters_used": model_info["parameters"],
                    "message": result.get("message"),
                    "usage": usage,
                    "elapsed_time": elapsed,
                    "tokens_per_second": (
                        usage.get("completion_tokens", 0) / elapsed
                        if elapsed > 0 and usage.get("completion_tokens") else None
                    ),
                    "cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
                    "error": None
                }

            # asyncio.wait_for raises asyncio.TimeoutError (an alias of the builtin
            # TimeoutError only from Python 3.11 on)
            except asyncio.TimeoutError:
                logger.warning(
                    "Model run timed out",
                    extra={
                        "model_config_id": str(model_info["model_config_id"]),
                        "timeout": timeout
                    }
                )
                return {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "conversation_id": model_conversation_id,
                    "label": model_info["label"],
                    "parameters_used": model_info["parameters"],
                    "elapsed_time": timeout,
                    "error": f"执行超时({timeout}秒)"  # "Timed out after {timeout}s"
                }
            except Exception as e:
                logger.error(
                    "Model run failed",
                    extra={
                        "model_config_id": str(model_info["model_config_id"]),
                        "error": str(e)
                    }
                )
                return {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "label": model_info["label"],
                    "conversation_id": model_conversation_id,
                    "parameters_used": model_info["parameters"],
                    "elapsed_time": 0,
                    "error": str(e)
                }

        # Remember the original parameters once, so they are restored at the end even if
        # the per-model overrides interleave (in parallel mode all tasks share agent_config)
        base_params = agent_config.model_parameters
        try:
            # Run all models (in parallel or sequentially)
            if parallel:
                logger.debug(f"Running {len(models)} models in parallel")
                results = await asyncio.gather(
                    *[run_single_model(m) for m in models],
                    return_exceptions=False
                )
            else:
                logger.debug(f"Running {len(models)} models sequentially")
                results = []
                for model_info in models:
                    result = await run_single_model(model_info)
                    results.append(result)
        finally:
            agent_config.model_parameters = base_params

        # Aggregate statistics
        successful = [r for r in results if not r.get("error")]
        failed = [r for r in results if r.get("error")]

        fastest = min(successful, key=lambda x: x["elapsed_time"]) if successful else None
        # Only models with a cost estimate can be ranked by cost
        priced = [r for r in successful if r.get("cost_estimate") is not None]
        cheapest = min(priced, key=lambda x: x["cost_estimate"]) if priced else None
        # Sum of the per-model times; in parallel mode this exceeds the wall-clock time
        total_elapsed_time = sum(r.get("elapsed_time", 0) for r in results)

        logger.info(
            "Multi-model comparison finished",
            extra={
                "successful": len(successful),
                "failed": len(failed),
                "total_time": total_elapsed_time
            }
        )

        return {
            "results": results,
            "total_elapsed_time": total_elapsed_time,
            "successful_count": len(successful),
            "failed_count": len(failed),
            "fastest_model": fastest["label"] if fastest else None,
            "cheapest_model": cheapest["label"] if cheapest else None
        }
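
    # Illustrative call of run_compare. Every name below is hypothetical:
    # model_config_a/b stand for ModelConfig rows loaded by the caller, and the
    # "temperature" key is only an example of a model parameter.
    #
    #   service = DraftRunService(db)
    #   models = [
    #       {"model_config_id": cfg_a_id, "model_config": model_config_a,
    #        "parameters": {"temperature": 0.2}, "label": "A"},
    #       {"model_config_id": cfg_b_id, "model_config": model_config_b,
    #        "parameters": {"temperature": 0.9}, "label": "B",
    #        "conversation_id": existing_conversation_id},  # optional, per model
    #   ]
    #   summary = await service.run_compare(
    #       agent_config=agent_config, models=models,
    #       message="Summarize our refund policy", workspace_id=workspace_id,
    #   )
    #   summary["results"]        # one dict per model, in the order of `models`
    #   summary["fastest_model"]  # label of the fastest successful model, or None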

    def _estimate_cost(self, usage: Dict[str, Any], model_config) -> Optional[float]:
        """Estimate the cost of a call.

        Args:
            usage: Token usage
            model_config: Model configuration

        Returns:
            Optional[float]: Estimated cost in USD (currently always None)
        """
        if not usage:
            return None

        prompt_tokens = usage.get("prompt_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0)

        # Simplified cost estimation: return None for now
        # TODO: implement cost estimation based on the model name or configuration
        # This needs the actual model name from ModelApiKey, or a model field on ModelConfig
        return None

    def _with_parameters(
        self, agent_config: AgentConfig, parameters: Dict[str, Any]
    ) -> tuple[AgentConfig, Any]:
        """Temporarily override model_parameters on agent_config (no copy is made).

        Args:
            agent_config: Original agent configuration
            parameters: Parameters to apply

        Returns:
            tuple: (agent_config, original_params). agent_config is the same object,
            now carrying the new model_parameters; the caller must restore
            original_params afterwards.
        """
        # Save the original parameters
        original_params = agent_config.model_parameters
        # Apply the new parameters
        agent_config.model_parameters = parameters
        return agent_config, original_params
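
    # Illustrative use of _with_parameters (hypothetical caller; the parameter values
    # and the elided run() arguments are placeholders). Because it mutates agent_config
    # in place, the caller restores the original parameters in a finally block:
    #
    #   agent_config, original_params = self._with_parameters(
    #       agent_config, {"temperature": 0.2}
    #   )
    #   try:
    #       result = await self.run(agent_config=agent_config, ...)
    #   finally:
    #       agent_config.model_parameters = original_params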

    async def run_compare_stream(
        self,
        *,
        agent_config: AgentConfig,
        models: List[Dict[str, Any]],
        message: str,
        workspace_id: uuid.UUID,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
        web_search: bool = True,
        memory: bool = True,
        parallel: bool = True,
        timeout: int = 60
    ) -> AsyncGenerator[str, None]:
        """Run a draft comparison across multiple models, streaming the results.

        Mirrors run_compare and supports both parallel and sequential execution.

        Args:
            agent_config: Agent configuration
            models: Model entries; each contains model_config, parameters, label,
                model_config_id and, optionally, its own conversation_id
            message: User message
            workspace_id: Workspace ID
            conversation_id: Conversation ID (used by models without their own conversation_id)
            user_id: User ID
            variables: Variable values
            storage_type: Storage type
            user_rag_memory_id: RAG memory ID
            web_search: Whether web search is enabled
            memory: Whether memory is enabled
            parallel: Whether to run the models in parallel
            timeout: Timeout in seconds

        Yields:
            str: SSE-formatted event data
        """
        logger.info(
            "Multi-model comparison streaming draft run",
            extra={"model_count": len(models), "parallel": parallel}
        )

        # Send the start event
        yield self._format_sse_event("compare_start", {
            "conversation_id": conversation_id,
            "model_count": len(models),
            "parallel": parallel,
            "timestamp": time.time()
        })

        results = []

        async def run_single_model_stream(idx: int, model_info: Dict[str, Any], event_queue: asyncio.Queue):
            """Run a single model (streaming) and push its events onto the queue."""
            model_label = model_info["label"]
            model_config_id = str(model_info["model_config_id"])
            # Use the model's own conversation_id, falling back to the shared one
            model_conversation_id = model_info.get("conversation_id") or conversation_id

            try:
                # Send the model start event
                await event_queue.put(self._format_sse_event("model_start", {
                    "model_index": idx,
                    "model_config_id": model_config_id,
                    "model_name": model_info["model_config"].name,
                    "label": model_label,
                    "conversation_id": model_conversation_id,
                    "timestamp": time.time()
                }))

                start_time = time.time()
                full_content = ""
                returned_conversation_id = model_conversation_id

                # Temporarily override the parameters
                original_params = agent_config.model_parameters
                agent_config.model_parameters = model_info["parameters"]

                async def consume_stream() -> None:
                    """Consume this model's stream, forwarding message chunks to the queue."""
                    nonlocal full_content, returned_conversation_id
                    async for event_str in self.run_stream(
                        agent_config=agent_config,
                        model_config=model_info["model_config"],
                        message=message,
                        workspace_id=workspace_id,
                        conversation_id=model_conversation_id,
                        user_id=user_id,
                        variables=variables,
                        storage_type=storage_type,
                        user_rag_memory_id=user_rag_memory_id,
                        web_search=web_search,
                        memory=memory
                    ):
                        # Parse the raw SSE event
                        event_type = None
                        event_data = None
                        try:
                            for line in event_str.strip().split("\n"):
                                if line.startswith("event: "):
                                    event_type = line[7:].strip()
                                elif line.startswith("data: "):
                                    event_data = json.loads(line[6:])
                        except Exception as e:
                            logger.warning(f"Failed to parse streaming event: {e}")
                            continue

                        # run_stream reports failures as an "error" event instead of raising,
                        # so turn it into an exception to mark this model as failed
                        if event_type == "error":
                            error = event_data.get("error") if isinstance(event_data, dict) else None
                            raise RuntimeError(error or "streaming agent call failed")

                        # Take the actual conversation_id from the start event
                        if event_type == "start" and event_data:
                            conv_id = event_data.get("conversation_id")
                            if conv_id:
                                returned_conversation_id = conv_id

                        # Accumulate the message content
                        if event_type == "message" and event_data:
                            chunk = event_data.get("content") or ""
                            full_content += chunk

                            # Forward the chunk, tagged with the model it came from
                            await event_queue.put(self._format_sse_event("model_message", {
                                "model_index": idx,
                                "model_config_id": model_config_id,
                                "label": model_label,
                                "conversation_id": returned_conversation_id,
                                "content": chunk
                            }))

                try:
                    # Apply the timeout to this model's whole stream
                    await asyncio.wait_for(consume_stream(), timeout=timeout)
                finally:
                    # Restore the original parameters
                    agent_config.model_parameters = original_params

                elapsed = time.time() - start_time

                # Build the result (same shape as in run_compare)
                result = {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "label": model_label,
                    "conversation_id": returned_conversation_id,
                    "parameters_used": model_info["parameters"],
                    "message": full_content,
                    "elapsed_time": elapsed,
                    "error": None
                }

                # Send the model end event
                await event_queue.put(self._format_sse_event("model_end", {
                    "model_index": idx,
                    "model_config_id": model_config_id,
                    "label": model_label,
                    "conversation_id": returned_conversation_id,
                    "elapsed_time": elapsed,
                    "message_length": len(full_content),
                    "timestamp": time.time()
                }))

                return result

            except asyncio.TimeoutError:
                logger.warning(f"Model run timed out: {model_label}")
                result = {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "label": model_label,
                    "conversation_id": model_conversation_id,
                    "parameters_used": model_info["parameters"],
                    "elapsed_time": timeout,
                    "error": f"执行超时({timeout}秒)"  # "Timed out after {timeout}s"
                }

                await event_queue.put(self._format_sse_event("model_error", {
                    "model_index": idx,
                    "model_config_id": model_config_id,
                    "label": model_label,
                    "conversation_id": model_conversation_id,
                    "error": result["error"],
                    "timestamp": time.time()
                }))

                return result

            except Exception as e:
                logger.error(f"Model run failed: {model_label}, error: {e}")
                result = {
                    "model_config_id": model_info["model_config_id"],
                    "model_name": model_info["model_config"].name,
                    "label": model_label,
                    "conversation_id": model_conversation_id,
                    "parameters_used": model_info["parameters"],
                    "elapsed_time": 0,
                    "error": str(e)
                }

                await event_queue.put(self._format_sse_event("model_error", {
                    "model_index": idx,
                    "model_config_id": model_config_id,
                    "label": model_label,
                    "conversation_id": model_conversation_id,
                    "error": str(e),
                    "timestamp": time.time()
                }))

                return result

        # Remember the original parameters once, so they are restored at the end even if
        # the per-model overrides interleave (in parallel mode all tasks share agent_config)
        base_params = agent_config.model_parameters
        try:
            if parallel:
                # Run all models in parallel (same approach as run_compare)
                logger.debug(f"Running {len(models)} models in parallel (streaming)")

                # One event queue shared by all models
                event_queue = asyncio.Queue()

                # Start one task per model
                tasks = [
                    asyncio.create_task(run_single_model_stream(idx, model_info, event_queue))
                    for idx, model_info in enumerate(models)
                ]

                # Forward events as they arrive until every task has finished
                # and the queue has been drained
                while not all(task.done() for task in tasks) or not event_queue.empty():
                    try:
                        event = await asyncio.wait_for(event_queue.get(), timeout=0.1)
                        yield event
                    except asyncio.TimeoutError:
                        continue

                # Collect the results in the order of `models`
                for task in tasks:
                    try:
                        result = task.result()
                        if result:
                            results.append(result)
                    except Exception as e:
                        logger.error(f"Failed to get task result: {e}")

            else:
                # Run the models one after another (same approach as run_compare)
                logger.debug(f"Running {len(models)} models sequentially (streaming)")

                for idx, model_info in enumerate(models):
                    # A fresh queue for each model
                    event_queue = asyncio.Queue()
                    task = asyncio.create_task(run_single_model_stream(idx, model_info, event_queue))

                    # Forward this model's events while it is still running,
                    # rather than only after it has finished
                    while not task.done() or not event_queue.empty():
                        try:
                            event = await asyncio.wait_for(event_queue.get(), timeout=0.1)
                            yield event
                        except asyncio.TimeoutError:
                            continue

                    try:
                        result = task.result()
                        if result:
                            results.append(result)
                    except Exception as e:
                        logger.error(f"Failed to get task result: {e}")
        finally:
            agent_config.model_parameters = base_params

        # Aggregate statistics (same as run_compare)
        successful = [r for r in results if not r.get("error")]
        failed = [r for r in results if r.get("error")]

        fastest = min(successful, key=lambda x: x["elapsed_time"]) if successful else None
        # Only models with a cost estimate can be ranked by cost
        priced = [r for r in successful if r.get("cost_estimate") is not None]
        cheapest = min(priced, key=lambda x: x["cost_estimate"]) if priced else None
        # Sum of the per-model times; in parallel mode this exceeds the wall-clock time
        total_elapsed_time = sum(r.get("elapsed_time", 0) for r in results)

        # Build the result summary (including the full messages)
        results_summary = []
        for r in results:
            results_summary.append({
                "model_config_id": str(r["model_config_id"]),
                "model_name": r["model_name"],
                "label": r["label"],
                "conversation_id": r.get("conversation_id"),
                "message": r.get("message"),  # full message
                "elapsed_time": r.get("elapsed_time", 0),
                "error": r.get("error")
            })

        # Send the comparison end event (same fields as run_compare's return value)
        yield self._format_sse_event("compare_end", {
            "conversation_id": conversation_id,
            "results": results_summary,  # full per-model results
            "total_elapsed_time": total_elapsed_time,
            "successful_count": len(successful),
            "failed_count": len(failed),
            "fastest_model": fastest["label"] if fastest else None,
            "cheapest_model": cheapest["label"] if cheapest else None,
            "timestamp": time.time()
        })

        logger.info(
            "Multi-model comparison streaming run finished",
            extra={
                "successful": len(successful),
                "failed": len(failed),
                "total_time": total_elapsed_time
            }
        )
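
    # Illustrative SSE sequence from run_compare_stream with two models in parallel
    # mode, one event per row (payloads abbreviated, values made up; model_index is
    # the position in `models`, and the interleaving shown is only one possibility):
    #
    #   event: compare_start  data: {"model_count": 2, "parallel": true, ...}
    #   event: model_start    data: {"model_index": 0, "label": "A", ...}
    #   event: model_start    data: {"model_index": 1, "label": "B", ...}
    #   event: model_message  data: {"model_index": 1, "content": "Hel", ...}
    #   event: model_message  data: {"model_index": 0, "content": "Our", ...}
    #   event: model_end      data: {"model_index": 0, "message_length": 42, ...}
    #   event: model_error    data: {"model_index": 1, "error": "...", ...}
    #   event: compare_end    data: {"successful_count": 1, "failed_count": 1, ...}
    #
    # Chunks from different models interleave, so a client should route each
    # model_message by model_index (or label) and treat model_end or model_error
    # as the end of that model's output.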


async def draft_run(
    db: Session,
    *,
    agent_config: AgentConfig,
    model_config: ModelConfig,
    message: str,
    user_id: Optional[str] = None,
    kb_ids: Optional[List[str]] = None,
    similarity_threshold: float = 0.7,
    top_k: int = 3
) -> Dict[str, Any]:
    """Draft-run an agent (convenience wrapper around DraftRunService.run).

    Args:
        db: Database session
        agent_config: Agent configuration
        model_config: Model configuration
        message: User message
        user_id: User ID
        kb_ids: Knowledge base IDs
        similarity_threshold: Similarity threshold for retrieval
        top_k: Number of documents returned by retrieval

    Returns:
        Dict: The AI reply plus metadata
    """
    service = DraftRunService(db)
    return await service.run(
        agent_config=agent_config,
        model_config=model_config,
        message=message,
        user_id=user_id,
        kb_ids=kb_ids,
        similarity_threshold=similarity_threshold,
        top_k=top_k
    )