Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

This commit is contained in:
Mark
2026-02-06 18:56:35 +08:00
8 changed files with 229 additions and 195 deletions

View File

@@ -802,7 +802,8 @@ async def draft_run_compare(
web_search=True, web_search=True,
memory=True, memory=True,
parallel=payload.parallel, parallel=payload.parallel,
timeout=payload.timeout or 60 timeout=payload.timeout or 60,
files=payload.files
): ):
yield event yield event
@@ -906,14 +907,14 @@ def get_app_statistics(
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
stats_service = AppStatisticsService(db) stats_service = AppStatisticsService(db)
result = stats_service.get_app_statistics( result = stats_service.get_app_statistics(
app_id=app_id, app_id=app_id,
workspace_id=workspace_id, workspace_id=workspace_id,
start_date=start_date, start_date=start_date,
end_date=end_date end_date=end_date
) )
return success(data=result) return success(data=result)
@@ -940,11 +941,11 @@ def get_workspace_api_statistics(
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
stats_service = AppStatisticsService(db) stats_service = AppStatisticsService(db)
result = stats_service.get_workspace_api_statistics( result = stats_service.get_workspace_api_statistics(
workspace_id=workspace_id, workspace_id=workspace_id,
start_date=start_date, start_date=start_date,
end_date=end_date end_date=end_date
) )
return success(data=result) return success(data=result)

View File

@@ -104,14 +104,18 @@ async def start_workspace_reflection(
) -> dict: ) -> dict:
"""启动工作空间中所有匹配应用的反思功能""" """启动工作空间中所有匹配应用的反思功能"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
reflection_service = MemoryReflectionService(db)
try: try:
api_logger.info(f"用户 {current_user.username} 启动workspace反思workspace_id: {workspace_id}") api_logger.info(f"用户 {current_user.username} 启动workspace反思workspace_id: {workspace_id}")
service = WorkspaceAppService(db) # 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
result = service.get_workspace_apps_detailed(workspace_id) from app.db import get_db_context
with get_db_context() as query_db:
service = WorkspaceAppService(query_db)
result = service.get_workspace_apps_detailed(workspace_id)
reflection_results = [] reflection_results = []
for data in result['apps_detailed_info']: for data in result['apps_detailed_info']:
# 跳过没有配置的应用 # 跳过没有配置的应用
if not data['memory_configs']: if not data['memory_configs']:
@@ -133,33 +137,36 @@ async def start_workspace_reflection(
api_logger.debug(f"配置 {config_id_str} 没有匹配的release") api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
continue continue
# 为每个用户执行反思 # 为每个用户执行反思 - 使用独立的数据库会话
for user in end_users: for user in end_users:
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config_id_str}") api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config_id_str}")
try: # 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
reflection_result = await reflection_service.start_text_reflection( with get_db_context() as user_db:
config_data=config, try:
end_user_id=user['id'] reflection_service = MemoryReflectionService(user_db)
) reflection_result = await reflection_service.start_text_reflection(
config_data=config,
end_user_id=user['id']
)
reflection_results.append({ reflection_results.append({
"app_id": data['id'], "app_id": data['id'],
"config_id": config_id_str, "config_id": config_id_str,
"end_user_id": user['id'], "end_user_id": user['id'],
"reflection_result": reflection_result "reflection_result": reflection_result
}) })
except Exception as e: except Exception as e:
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}") api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
reflection_results.append({ reflection_results.append({
"app_id": data['id'], "app_id": data['id'],
"config_id": config_id_str, "config_id": config_id_str,
"end_user_id": user['id'], "end_user_id": user['id'],
"reflection_result": { "reflection_result": {
"status": "错误", "status": "错误",
"message": f"反思失败: {str(e)}" "message": f"反思失败: {str(e)}"
} }
}) })
return success(data=reflection_results, msg="反思配置成功") return success(data=reflection_results, msg="反思配置成功")

View File

@@ -462,8 +462,8 @@ class ReflectionEngine:
List[Any]: 反思数据列表 List[Any]: 反思数据列表
""" """
print("=== 获取反思数据 ===")
print(f" 主机ID: {host_id}")
if self.config.reflexion_range == ReflectionRange.PARTIAL: if self.config.reflexion_range == ReflectionRange.PARTIAL:
neo4j_query = neo4j_query_part.format(host_id) neo4j_query = neo4j_query_part.format(host_id)
neo4j_statement = neo4j_statement_part.format(host_id) neo4j_statement = neo4j_statement_part.format(host_id)

View File

@@ -488,7 +488,7 @@ class DraftRunCompareRequest(BaseModel):
max_length=5, max_length=5,
description="要对比的模型列表1-5个" description="要对比的模型列表1-5个"
) )
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
parallel: bool = Field(True, description="是否并行执行") parallel: bool = Field(True, description="是否并行执行")
stream: bool = Field(False, description="是否流式返回") stream: bool = Field(False, description="是否流式返回")
timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)") timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)")

View File

@@ -16,26 +16,26 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.celery_app import celery_app from app.celery_app import celery_app
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig from app.models import AgentConfig, ModelConfig
from app.repositories.model_repository import ModelApiKeyRepository
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.schemas.app_schema import FileInput from app.schemas.app_schema import FileInput
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service from app.services import task_service
from app.services.langchain_tool_server import Search from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_parameter_merger import ModelParameterMerger
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.tool_service import ToolService
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.core.agent.agent_middleware import AgentMiddleware from app.services.tool_service import ToolService
logger = get_business_logger() logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel): class KnowledgeRetrievalInput(BaseModel):
"""知识库检索工具输入参数""" """知识库检索工具输入参数"""
query: str = Field(description="需要检索的问题或关键词") query: str = Field(description="需要检索的问题或关键词")
@@ -48,9 +48,12 @@ class WebSearchInput(BaseModel):
class LongTermMemoryInput(BaseModel): class LongTermMemoryInput(BaseModel):
"""长期记忆工具输入参数""" """长期记忆工具输入参数"""
question: str = Field(description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写") question: str = Field(
description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,user_rag_memory_id: Optional[str] = None):
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None):
"""创建记忆工具, """创建记忆工具,
@@ -66,6 +69,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content # 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None) config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None)
logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}") logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
@tool(args_schema=LongTermMemoryInput) @tool(args_schema=LongTermMemoryInput)
def long_term_memory(question: str) -> str: def long_term_memory(question: str) -> str:
""" """
@@ -133,6 +137,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
except Exception as e: except Exception as e:
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__}) logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
return f"记忆检索失败: {str(e)}" return f"记忆检索失败: {str(e)}"
return long_term_memory return long_term_memory
@@ -179,7 +184,7 @@ def create_web_search_tool(web_search_config: Dict[str, Any]):
return web_search_tool return web_search_tool
def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id): def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
Args: Args:
@@ -189,6 +194,7 @@ def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id):
检索到的相关知识内容 检索到的相关知识内容
""" """
logger.info(f"创建知识库检索工具,用户:{user_id}") logger.info(f"创建知识库检索工具,用户:{user_id}")
@tool(args_schema=KnowledgeRetrievalInput) @tool(args_schema=KnowledgeRetrievalInput)
def knowledge_retrieval_tool(query: str) -> str: def knowledge_retrieval_tool(query: str) -> str:
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
@@ -200,7 +206,6 @@ def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id):
检索到的相关知识内容 检索到的相关知识内容
""" """
try: try:
retrieve_chunks_result = knowledge_retrieval(query, kb_config) retrieve_chunks_result = knowledge_retrieval(query, kb_config)
@@ -226,6 +231,7 @@ def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id):
return knowledge_retrieval_tool return knowledge_retrieval_tool
class DraftRunService: class DraftRunService:
"""试运行服务类""" """试运行服务类"""
@@ -238,21 +244,21 @@ class DraftRunService:
self.db = db self.db = db
async def run( async def run(
self, self,
*, *,
agent_config: AgentConfig, agent_config: AgentConfig,
model_config: ModelConfig, model_config: ModelConfig,
message: str, message: str,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
web_search: bool = True, web_search: bool = True,
memory: bool = True, memory: bool = True,
sub_agent: bool = False, sub_agent: bool = False,
files: Optional[List[FileInput]] = None # 新增:多模态文件 files: Optional[List[FileInput]] = None # 新增:多模态文件
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""执行试运行(使用 LangChain Agent """执行试运行(使用 LangChain Agent
@@ -268,9 +274,9 @@ class DraftRunService:
Returns: Returns:
Dict: 包含 AI 回复和元数据的字典 Dict: 包含 AI 回复和元数据的字典
""" """
memory_flag=False memory_flag = False
print('===========',storage_type) print('===========', storage_type)
print(user_id) print(user_id)
if variables == None: variables = {} if variables == None: variables = {}
@@ -296,8 +302,7 @@ class DraftRunService:
agent_config=agent_config agent_config=agent_config
) )
items_params = variables
items_params=variables
system_prompt = render_prompt_message( system_prompt = render_prompt_message(
agent_config.system_prompt, # 修正拼写错误 agent_config.system_prompt, # 修正拼写错误
PromptMessageRole.USER, PromptMessageRole.USER,
@@ -306,7 +311,7 @@ class DraftRunService:
# 3. 处理系统提示词(支持变量替换) # 3. 处理系统提示词(支持变量替换)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
print('系统提示词:',system_prompt) print('系统提示词:', system_prompt)
# 4. 准备工具列表 # 4. 准备工具列表
tools = [] tools = []
@@ -318,7 +323,7 @@ class DraftRunService:
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
if hasattr(agent_config, 'tools') and agent_config.tools: if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_config in agent_config.tools: for tool_config in agent_config.tools:
print("+"*50) print("+" * 50)
print(f"agent_config:{agent_config}") print(f"agent_config:{agent_config}")
print(f"tool_config:{tool_config}") print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False): if tool_config.get("enabled", False):
@@ -358,7 +363,8 @@ class DraftRunService:
# 应用动态过滤 # 应用动态过滤
if skill_configs: if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具") logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts( active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs activated_skill_ids, skill_configs
@@ -372,7 +378,7 @@ class DraftRunService:
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids: if kb_ids:
# 创建知识库检索工具 # 创建知识库检索工具
kb_tool = create_knowledge_retrieval_tool(kb_config,kb_ids,user_id) kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
tools.append(kb_tool) tools.append(kb_tool)
logger.debug( logger.debug(
@@ -386,12 +392,13 @@ class DraftRunService:
# 添加长期记忆工具 # 添加长期记忆工具
if memory: if memory:
if agent_config.memory and agent_config.memory.get("enabled"): if agent_config.memory and agent_config.memory.get("enabled"):
memory_flag=True memory_flag = True
memory_config = agent_config.memory memory_config = agent_config.memory
if user_id: if user_id:
# 创建长期记忆工具 # 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id,storage_type,user_rag_memory_id) memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool) tools.append(memory_tool)
logger.debug( logger.debug(
@@ -452,7 +459,7 @@ class DraftRunService:
} }
) )
memory_config_= agent_config.memory memory_config_ = agent_config.memory
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content # 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None) config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
@@ -518,21 +525,21 @@ class DraftRunService:
raise BusinessException(f"Agent 调用失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e) raise BusinessException(f"Agent 调用失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
async def run_stream( async def run_stream(
self, self,
*, *,
agent_config: AgentConfig, agent_config: AgentConfig,
model_config: ModelConfig, model_config: ModelConfig,
message: str, message: str,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
web_search: bool = True, # 布尔类型默认值 web_search: bool = True, # 布尔类型默认值
memory: bool = True, # 布尔类型默认值 memory: bool = True, # 布尔类型默认值
sub_agent: bool = False, # 是否是作为子Agent运行 sub_agent: bool = False, # 是否是作为子Agent运行
files: Optional[List[FileInput]] = None # 新增:多模态文件 files: Optional[List[FileInput]] = None # 新增:多模态文件
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""执行试运行(流式返回,使用 LangChain Agent """执行试运行(流式返回,使用 LangChain Agent
@@ -549,8 +556,8 @@ class DraftRunService:
Yields: Yields:
str: SSE 格式的事件数据 str: SSE 格式的事件数据
""" """
memory_flag=False memory_flag = False
if variables==None:variables={} if variables == None: variables = {}
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
@@ -566,7 +573,7 @@ class DraftRunService:
agent_config=agent_config agent_config=agent_config
) )
items_params=variables items_params = variables
system_prompt = render_prompt_message( system_prompt = render_prompt_message(
agent_config.system_prompt, # 修正拼写错误 agent_config.system_prompt, # 修正拼写错误
@@ -626,14 +633,14 @@ class DraftRunService:
# 应用动态过滤 # 应用动态过滤
if skill_configs: if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具") logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts( active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs activated_skill_ids, skill_configs
) )
system_prompt = f"{system_prompt}\n\n{active_prompts}" system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具 # 添加知识库检索工具
if agent_config.knowledge_retrieval: if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval kb_config = agent_config.knowledge_retrieval
@@ -654,11 +661,12 @@ class DraftRunService:
# 添加长期记忆工具 # 添加长期记忆工具
if memory: if memory:
if agent_config.memory and agent_config.memory.get("enabled"): if agent_config.memory and agent_config.memory.get("enabled"):
memory_flag= True memory_flag = True
memory_config = agent_config.memory memory_config = agent_config.memory
if user_id: if user_id:
# 创建长期记忆工具 # 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id,storage_type,user_rag_memory_id) memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool) tools.append(memory_tool)
logger.debug( logger.debug(
@@ -724,15 +732,15 @@ class DraftRunService:
full_content = "" full_content = ""
total_tokens = 0 total_tokens = 0
async for chunk in agent.chat_stream( async for chunk in agent.chat_stream(
message=message, message=message,
history=history, history=history,
context=context, context=context,
end_user_id=user_id, end_user_id=user_id,
config_id=config_id, config_id=config_id,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag, memory_flag=memory_flag,
files=processed_files # 传递处理后的文件 files=processed_files # 传递处理后的文件
): ):
if isinstance(chunk, int): if isinstance(chunk, int):
total_tokens = chunk total_tokens = chunk
@@ -749,8 +757,8 @@ class DraftRunService:
if sub_agent: if sub_agent:
yield self._format_sse_event("sub_usage", { yield self._format_sse_event("sub_usage", {
"total_tokens": total_tokens "total_tokens": total_tokens
}) })
# 10. 保存会话消息 # 10. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
@@ -842,11 +850,11 @@ class DraftRunService:
} }
async def _ensure_conversation( async def _ensure_conversation(
self, self,
conversation_id: Optional[str], conversation_id: Optional[str],
app_id: uuid.UUID, app_id: uuid.UUID,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
user_id: Optional[str] user_id: Optional[str]
) -> str: ) -> str:
"""确保会话存在(创建或验证) """确保会话存在(创建或验证)
@@ -863,7 +871,6 @@ class DraftRunService:
BusinessException: 当指定的会话不存在时 BusinessException: 当指定的会话不存在时
""" """
from app.models import Conversation as ConversationModel from app.models import Conversation as ConversationModel
from app.schemas.conversation_schema import ConversationCreate
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db) conversation_service = ConversationService(self.db)
@@ -945,9 +952,9 @@ class DraftRunService:
) )
async def _load_conversation_history( async def _load_conversation_history(
self, self,
conversation_id: str, conversation_id: str,
max_history: int = 10 max_history: int = 10
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
"""加载会话历史消息 """加载会话历史消息
@@ -984,13 +991,13 @@ class DraftRunService:
return [] return []
async def _save_conversation_message( async def _save_conversation_message(
self, self,
conversation_id: str, conversation_id: str,
user_message: str, user_message: str,
assistant_message: str, assistant_message: str,
meta_data: dict, meta_data: dict,
app_id: Optional[uuid.UUID] = None, app_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None user_id: Optional[str] = None
) -> None: ) -> None:
"""保存会话消息(会话已通过 _ensure_conversation 确保存在) """保存会话消息(会话已通过 _ensure_conversation 确保存在)
@@ -1100,10 +1107,10 @@ class DraftRunService:
return {} return {}
def _replace_variables( def _replace_variables(
self, self,
text: str, text: str,
values: Dict[str, Any], values: Dict[str, Any],
definitions: List[Dict[str, Any]] definitions: List[Dict[str, Any]]
) -> str: ) -> str:
"""替换文本中的变量 """替换文本中的变量
@@ -1129,8 +1136,8 @@ class DraftRunService:
# 替换变量(支持多种格式) # 替换变量(支持多种格式)
placeholders = [ placeholders = [
f"{{{{{var_name}}}}}", # {{var_name}} f"{{{{{var_name}}}}}", # {{var_name}}
f"{{{var_name}}}", # {var_name} f"{{{var_name}}}", # {var_name}
f"${{{var_name}}}", # ${var_name} f"${{{var_name}}}", # ${var_name}
] ]
for placeholder in placeholders: for placeholder in placeholders:
@@ -1142,21 +1149,22 @@ class DraftRunService:
# ==================== 多模型对比试运行 ==================== # ==================== 多模型对比试运行 ====================
async def run_compare( async def run_compare(
self, self,
*, *,
agent_config: AgentConfig, agent_config: AgentConfig,
models: List[Dict[str, Any]], models: List[Dict[str, Any]],
message: str, message: str,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
parallel: bool = True, parallel: bool = True,
timeout: int = 60, timeout: int = 60,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
web_search: bool = True, web_search: bool = True,
memory: bool = True, memory: bool = True,
files: list[FileInput] | None = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""多模型对比试运行 """多模型对比试运行
@@ -1206,7 +1214,8 @@ class DraftRunService:
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
web_search=web_search, web_search=web_search,
memory=memory memory=memory,
files=files
), ),
timeout=timeout timeout=timeout
) )
@@ -1221,7 +1230,7 @@ class DraftRunService:
"model_config_id": model_info["model_config_id"], "model_config_id": model_info["model_config_id"],
"model_name": model_info["model_config"].name, "model_name": model_info["model_config"].name,
"label": model_info["label"], "label": model_info["label"],
"conversation_id":result['conversation_id'], "conversation_id": result['conversation_id'],
"parameters_used": model_info["parameters"], "parameters_used": model_info["parameters"],
"message": result.get("message"), "message": result.get("message"),
"usage": usage, "usage": usage,
@@ -1349,21 +1358,22 @@ class DraftRunService:
return agent_config, original_params return agent_config, original_params
async def run_compare_stream( async def run_compare_stream(
self, self,
*, *,
agent_config: AgentConfig, agent_config: AgentConfig,
models: List[Dict[str, Any]], models: List[Dict[str, Any]],
message: str, message: str,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
web_search: bool = True, web_search: bool = True,
memory: bool = True, memory: bool = True,
parallel: bool = True, parallel: bool = True,
timeout: int = 60 timeout: int = 60,
files: list[FileInput] | None = None
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""多模型对比试运行(流式返回) """多模型对比试运行(流式返回)
@@ -1383,6 +1393,7 @@ class DraftRunService:
memory: 是否启用记忆 memory: 是否启用记忆
parallel: 是否并行执行 parallel: 是否并行执行
timeout: 超时时间(秒) timeout: 超时时间(秒)
files: 多模态文件
Yields: Yields:
str: SSE 格式的事件数据 str: SSE 格式的事件数据
@@ -1431,17 +1442,18 @@ class DraftRunService:
try: try:
# 流式调用单个模型 # 流式调用单个模型
async for event_str in self.run_stream( async for event_str in self.run_stream(
agent_config=agent_config, agent_config=agent_config,
model_config=model_info["model_config"], model_config=model_info["model_config"],
message=message, message=message,
workspace_id=workspace_id, workspace_id=workspace_id,
conversation_id=model_conversation_id, conversation_id=model_conversation_id,
user_id=user_id, user_id=user_id,
variables=variables, variables=variables,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
web_search=web_search, web_search=web_search,
memory=memory memory=memory,
files=files
): ):
# 解析原始事件 # 解析原始事件
try: try:
@@ -1661,15 +1673,15 @@ class DraftRunService:
async def draft_run( async def draft_run(
db: Session, db: Session,
*, *,
agent_config: AgentConfig, agent_config: AgentConfig,
model_config: ModelConfig, model_config: ModelConfig,
message: str, message: str,
user_id: Optional[str] = None, user_id: Optional[str] = None,
kb_ids: Optional[List[str]] = None, kb_ids: Optional[List[str]] = None,
similarity_threshold: float = 0.7, similarity_threshold: float = 0.7,
top_k: int = 3 top_k: int = 3
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""试运行 Agent便捷函数 """试运行 Agent便捷函数
@@ -1696,4 +1708,3 @@ async def draft_run(
similarity_threshold=similarity_threshold, similarity_threshold=similarity_threshold,
top_k=top_k top_k=top_k
) )

View File

@@ -313,13 +313,16 @@ class MemoryAgentService:
start_time = time.time() start_time = time.time()
# Load configuration from database with workspace fallback # Load configuration from database with workspace fallback
# Use a separate database session to avoid transaction failures
try: try:
config_service = MemoryConfigService(db) from app.db import get_db_context
memory_config = config_service.load_memory_config( with get_db_context() as config_db:
config_id=config_id, config_service = MemoryConfigService(config_db)
workspace_id=workspace_id, memory_config = config_service.load_memory_config(
service_name="MemoryAgentService" config_id=config_id,
) workspace_id=workspace_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}") logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e: except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
@@ -454,12 +457,15 @@ class MemoryAgentService:
config_load_start = time.time() config_load_start = time.time()
try: try:
config_service = MemoryConfigService(db) # Use a separate database session to avoid transaction failures
memory_config = config_service.load_memory_config( from app.db import get_db_context
config_id=config_id, with get_db_context() as config_db:
workspace_id=workspace_id, config_service = MemoryConfigService(config_db)
service_name="MemoryAgentService" memory_config = config_service.load_memory_config(
) config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
config_load_time = time.time() - config_load_start config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}") logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
except ConfigurationError as e: except ConfigurationError as e:

View File

@@ -364,7 +364,8 @@ class MemoryReflectionService:
reflexion_range_value = config_data.get("reflexion_range") reflexion_range_value = config_data.get("reflexion_range")
if reflexion_range_value is None or reflexion_range_value == "": if reflexion_range_value is None or reflexion_range_value == "":
reflexion_range_value = "partial" reflexion_range_value = "partial"
# Map legacy/invalid values to valid enum values
# Map legacy/invalid values to valid enum values
reflexion_range_mapping = { reflexion_range_mapping = {
"retrieval": "partial", # Map old 'retrieval' to 'partial' "retrieval": "partial", # Map old 'retrieval' to 'partial'
"partial": "partial", "partial": "partial",
@@ -378,13 +379,19 @@ class MemoryReflectionService:
baseline_value = "TIME" baseline_value = "TIME"
baseline = ReflectionBaseline(baseline_value) baseline = ReflectionBaseline(baseline_value)
# iteration_period = # iteration_period
iteration_period = config_data.get("iteration_period", 24) iteration_period = config_data.get("iteration_period", 24)
if isinstance(iteration_period, str): if isinstance(iteration_period, str):
try: try:
iteration_period = int(iteration_period) iteration_period = int(iteration_period)
except (ValueError, TypeError): except (ValueError, TypeError):
iteration_period = 24 # 默认24小时 iteration_period = 24 # 默认24小时
# 获取 model_id 并转换为字符串(如果是 UUID 对象)
reflection_model_id = config_data.get("reflection_model_id", "")
if reflection_model_id:
reflection_model_id = str(reflection_model_id)
return ReflectionConfig( return ReflectionConfig(
enabled=config_data.get("enable_self_reflexion", False), enabled=config_data.get("enable_self_reflexion", False),
iteration_period=str(iteration_period), # ReflectionConfig期望字符串 iteration_period=str(iteration_period), # ReflectionConfig期望字符串
@@ -392,7 +399,7 @@ class MemoryReflectionService:
baseline=baseline, baseline=baseline,
memory_verify=config_data.get("memory_verify", False), memory_verify=config_data.get("memory_verify", False),
quality_assessment=config_data.get("quality_assessment", False), quality_assessment=config_data.get("quality_assessment", False),
model_id=config_data.get("reflection_model_id", "") model_id=reflection_model_id
) )
async def _execute_reflection_engine( async def _execute_reflection_engine(

View File

@@ -211,7 +211,8 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
// Process parser_config data, set default values if not present // Process parser_config data, set default values if not present
const recordAny = record as any; const recordAny = record as any;
baseValues.parser_config = record.parser_config || { baseValues.parser_config = {
...record.parser_config,
graphrag: { graphrag: {
use_graphrag: false, use_graphrag: false,
scene_name: '', scene_name: '',
@@ -219,6 +220,7 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
method: 'general', method: 'general',
resolution: false, resolution: false,
community: false, community: false,
...(record.parser_config?.graphrag || {})
} }
}; };
@@ -656,7 +658,7 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
{currentType !== 'Folder' && dynamicTypeList.map((tp) => { {currentType !== 'Folder' && dynamicTypeList.map((tp) => {
const fieldKey = typeToFieldKey(tp); const fieldKey = typeToFieldKey(tp);
// When tp is 'llm', merge llm and chat options // When tp is 'llm', merge llm and chat options
const options = tp.toLowerCase() === 'llm' const options = tp.toLowerCase() === 'llm' || tp.toLowerCase() === 'image2text'
? [...(modelOptionsByType['llm'] || []), ...(modelOptionsByType['chat'] || [])] ? [...(modelOptionsByType['llm'] || []), ...(modelOptionsByType['chat'] || [])]
: modelOptionsByType[tp] || []; : modelOptionsByType[tp] || [];
return ( return (