From 7acb7045f081046de2f2d381c26eb833062ff39c Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 30 Mar 2026 11:47:58 +0800 Subject: [PATCH] feat(agent, memory): add agent-perceived memory writing --- .../controllers/public_share_controller.py | 80 -------- api/app/core/agent/langchain_agent.py | 90 ++------- .../langgraph_graph/routing/write_router.py | 83 +++------ .../agent/langgraph_graph/write_graph.py | 116 ++++-------- api/app/core/memory/agent/utils/redis_tool.py | 173 +++++++++--------- api/app/core/memory/llm_tools/llm_client.py | 2 +- api/app/schemas/memory_agent_schema.py | 10 +- api/app/services/app_chat_service.py | 65 ++++--- api/app/services/draft_run_service.py | 13 +- api/app/services/memory_perceptual_service.py | 21 --- api/app/services/model_service.py | 138 +++++++------- api/app/services/shared_chat_service.py | 43 ++--- 12 files changed, 304 insertions(+), 530 deletions(-) diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index f5284b46..26902b07 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -410,30 +410,6 @@ async def chat( agent_config = agent_config_4_app_release(release) if payload.stream: - # async def event_generator(): - # async for event in service.chat_stream( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ): - # yield event - - # return StreamingResponse( - # event_generator(), - # media_type="text/event-stream", - # headers={ - # "Cache-Control": "no-cache", - # "Connection": "keep-alive", - # "X-Accel-Buffering": "no" - # } - # ) async def event_generator(): async for event in app_chat_service.agnet_chat_stream( message=payload.message, @@ -459,20 +435,6 @@ async def chat( "X-Accel-Buffering": "no" } ) - # 非流式返回 - # result = await service.chat( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ) - # return success(data=conversation_schema.ChatResponse(**result)) result = await app_chat_service.agnet_chat( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID @@ -531,48 +493,6 @@ async def chat( ) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) - # 多 Agent 流式返回 - # if payload.stream: - # async def event_generator(): - # async for event in service.multi_agent_chat_stream( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ): - # yield event - - # return StreamingResponse( - # event_generator(), - # media_type="text/event-stream", - # headers={ - # "Cache-Control": "no-cache", - # "Connection": "keep-alive", - # "X-Accel-Buffering": "no" - # } - # ) - - # # 多 Agent 非流式返回 - # result = await service.multi_agent_chat( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ) - - # return success(data=conversation_schema.ChatResponse(**result)) elif app_type == AppType.WORKFLOW: config = workflow_config_4_app_release(release) if not config.id: diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 464a668a..38821313 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,18 +11,14 @@ LangChain Agent 封装 import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.write_graph import write_long_term -from app.db import get_db -from app.core.logging_config import get_business_logger -from app.core.models import RedBearLLM, RedBearModelConfig -from app.models.models_model import ModelType, ModelProvider -from app.services.memory_agent_service import ( - get_end_user_connected_config, -) from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool +from app.core.logging_config import get_business_logger +from app.core.models import RedBearLLM, RedBearModelConfig +from app.models.models_model import ModelType + logger = get_business_logger() @@ -226,10 +222,9 @@ class LangChainAgent: Returns: List[BaseMessage]: 消息列表 """ - messages = [] + messages:list = [SystemMessage(content=self.system_prompt)] # 添加系统提示词 - messages.append(SystemMessage(content=self.system_prompt)) # 添加历史消息 if history: @@ -293,12 +288,7 @@ class LangChainAgent: message: str, history: Optional[List[Dict[str, str]]] = None, context: Optional[str] = None, - end_user_id: Optional[str] = None, - config_id: Optional[str] = None, # 添加这个参数 - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + files: Optional[List[Dict[str, Any]]] = None ) -> Dict[str, Any]: """执行对话 @@ -306,32 +296,12 @@ class LangChainAgent: message: 用户消息 history: 历史消息列表 [{"role": "user/assistant", "content": "..."}] context: 上下文信息(如知识库检索结果) + files: 多模态文件 Returns: Dict: 包含 content 和元数据的字典 """ - message_chat = message start_time = time.time() - actual_config_id = config_id - # If config_id is None, try to get from end_user's connected config - if actual_config_id is None and end_user_id: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) - db = next(get_db()) - try: - connected_config = get_end_user_connected_config(end_user_id, db) - actual_config_id = connected_config.get("memory_config_id") - except Exception as e: - logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}") - finally: - db.close() - except Exception as e: - logger.warning(f"Failed to get db session: {e}") - actual_end_user_id = end_user_id if end_user_id is not None else "unknown" - logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') - print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ -419,9 +389,6 @@ class LangChainAgent: logger.info(f"最终提取的内容长度: {len(content)}") elapsed_time = time.time() - start_time - if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, - actual_config_id) response = { "content": content, "model": self.model_name, @@ -452,12 +419,7 @@ class LangChainAgent: message: str, history: Optional[List[Dict[str, str]]] = None, context: Optional[str] = None, - end_user_id: Optional[str] = None, - config_id: Optional[str] = None, - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + files: Optional[List[Dict[str, Any]]] = None ) -> AsyncGenerator[str, None]: """执行流式对话 @@ -465,6 +427,7 @@ class LangChainAgent: message: 用户消息 history: 历史消息列表 context: 上下文信息 + files: 多模态文件 Yields: str: 消息内容块 @@ -475,23 +438,6 @@ class LangChainAgent: logger.info(f" Has tools: {bool(self.tools)}") logger.info(f" Tool count: {len(self.tools) if self.tools else 0}") logger.info("=" * 80) - message_chat = message - actual_config_id = config_id - # If config_id is None, try to get from end_user's connected config - if actual_config_id is None and end_user_id: - try: - db = next(get_db()) - try: - connected_config = get_end_user_connected_config(end_user_id, db) - actual_config_id = connected_config.get("memory_config_id") - except Exception as e: - logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}") - finally: - db.close() - except Exception as e: - logger.warning(f"Failed to get db session: {e}") - - # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ -501,17 +447,18 @@ class LangChainAgent: ) chunk_count = 0 - yielded_content = False # 统一使用 agent 的 astream_events 实现流式输出 logger.debug("使用 Agent astream_events 实现流式输出") full_content = '' try: + last_event = {} async for event in self.agent.astream_events( {"messages": messages}, version="v2", config={"recursion_limit": self.max_iterations} ): + last_event = event chunk_count += 1 kind = event.get("event") @@ -525,7 +472,6 @@ class LangChainAgent: if isinstance(chunk_content, str) and chunk_content: full_content += chunk_content yield chunk_content - yielded_content = True elif isinstance(chunk_content, list): # 多模态响应:提取文本部分 for item in chunk_content: @@ -536,18 +482,15 @@ class LangChainAgent: if text: full_content += text yield text - yielded_content = True # OpenAI 格式: {"type": "text", "text": "..."} elif item.get("type") == "text": text = item.get("text", "") if text: full_content += text yield text - yielded_content = True elif isinstance(item, str): full_content += item yield item - yielded_content = True elif kind == "on_llm_stream": # 另一种 LLM 流式事件 @@ -558,7 +501,6 @@ class LangChainAgent: if isinstance(chunk_content, str) and chunk_content: full_content += chunk_content yield chunk_content - yielded_content = True elif isinstance(chunk_content, list): # 多模态响应:提取文本部分 for item in chunk_content: @@ -569,22 +511,18 @@ class LangChainAgent: if text: full_content += text yield text - yielded_content = True # OpenAI 格式: {"type": "text", "text": "..."} elif item.get("type") == "text": text = item.get("text", "") if text: full_content += text yield text - yielded_content = True elif isinstance(item, str): full_content += item yield item - yielded_content = True elif isinstance(chunk, str): full_content += chunk yield chunk - yielded_content = True # 记录工具调用(可选) elif kind == "on_tool_start": @@ -594,7 +532,7 @@ class LangChainAgent: logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") # 统计token消耗 - output_messages = event.get("data", {}).get("output", {}).get("messages", []) + output_messages = last_event.get("data", {}).get("output", {}).get("messages", []) for msg in reversed(output_messages): if isinstance(msg, AIMessage): response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None @@ -604,9 +542,7 @@ class LangChainAgent: ) if response_meta else 0 yield total_tokens break - if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, - actual_config_id) + except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 2074b6ca..74fb6bae 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.memory_short_repository import LongTermMemoryRepository from app.schemas.memory_agent_schema import AgentMemory_Long_Term -from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task from app.utils.config_utils import resolve_config_id @@ -21,25 +20,6 @@ logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): - """ - Write messages to RAG storage system - - Combines user and AI messages into a single string format and stores them - in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval. - - Args: - end_user_id: User identifier for the conversation - user_message: User's input message content - ai_message: AI's response message content - user_rag_memory_id: RAG memory identifier for storage location - """ - # RAG mode: combine messages into string format (maintain original logic) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - - async def write( storage_type, end_user_id, @@ -118,7 +98,7 @@ async def write( logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') -async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope): +async def term_memory_save(end_user_id, strategy_type, scope): """ Save long-term memory data to database @@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty to long-term memory storage. Args: - long_term_messages: Long-term message data to be saved - actual_config_id: Configuration identifier for memory settings end_user_id: User identifier for memory association - type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) + strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) scope: Scope/window size for memory processing """ with get_db_context() as db_session: @@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty from app.core.memory.agent.utils.redis_tool import write_store result = write_store.get_session_by_userid(end_user_id) - if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + if not result: + logger.warning(f"No write data found for user {end_user_id}") + return + if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]: data = await format_parsing(result, "dict") chunk_data = data[:scope] if len(chunk_data) == scope: @@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty logger.info(f'写入短长期:') -"""Window-based dialogue processing""" - - async def window_dialogue(end_user_id, langchain_messages, memory_config, scope): """ Process dialogue based on window size and write to Neo4j @@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope) langchain_messages: Original message data list scope: Window size determining when to trigger long-term storage """ - scope = scope - is_end_user_id = count_store.get_sessions_count(end_user_id) - if is_end_user_id is not False: - is_end_user_id = count_store.get_sessions_count(end_user_id)[0] - redis_messages = count_store.get_sessions_count(end_user_id)[1] - if is_end_user_id and int(is_end_user_id) != int(scope): - is_end_user_id += 1 - langchain_messages += redis_messages - count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) - elif int(is_end_user_id) == int(scope): + is_end_user_has_history = count_store.get_sessions_count(end_user_id) + if is_end_user_has_history: + end_user_visit_count, redis_messages = is_end_user_has_history + else: + count_store.save_sessions_count(end_user_id, 1, langchain_messages) + return + end_user_visit_count += 1 + if end_user_visit_count < scope: + redis_messages.extend(langchain_messages) + count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages) + else: logger.info('写入长期记忆NEO4J') - formatted_messages = redis_messages + redis_messages.extend(langchain_messages) # Get config_id (if memory_config is an object, extract config_id; otherwise use directly) if hasattr(memory_config, 'config_id'): config_id = memory_config.config_id else: config_id = memory_config - await write( - AgentMemory_Long_Term.STORAGE_NEO4J, - end_user_id, - "", - "", - None, - end_user_id, - config_id, - formatted_messages + write_message_task.delay( + end_user_id, # end_user_id: User ID + redis_messages, # message: JSON string format message list + config_id, # config_id: Configuration ID string + AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j" + "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) ) - count_store.update_sessions_count(end_user_id, 1, langchain_messages) - else: - count_store.save_sessions_count(end_user_id, 1, langchain_messages) - - -"""Time-based memory processing""" + count_store.update_sessions_count(end_user_id, 0, []) async def memory_long_term_storage(end_user_id, memory_config, time): @@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config return result_dict except Exception as e: - print(f"[aggregate_judgment] 发生错误: {e}") - import traceback - traceback.print_exc() + logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True) return { "is_same_event": False, diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index bf3c6597..32fc7d8a 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,49 +1,25 @@ -import asyncio -import json -import sys import warnings -from contextlib import asynccontextmanager -from langgraph.constants import END, START -from langgraph.graph import StateGraph -from app.db import get_db, get_db_context from app.core.logging_config import get_agent_logger -from app.core.memory.agent.utils.llm_tools import WriteState -from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ + aggregate_judgment +from app.core.memory.agent.utils.redis_tool import write_store +from app.db import get_db_context from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_config_service import MemoryConfigService +from app.services.memory_konwledges_server import write_rag warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) -if sys.platform.startswith("win"): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - -@asynccontextmanager -async def make_write_graph(): - """ - Create a write graph workflow for memory operations. - - Args: - user_id: User identifier - tools: MCP tools loaded from session - apply_id: Application identifier - end_user_id: Group identifier - memory_config: MemoryConfig object containing all configuration - """ - workflow = StateGraph(WriteState) - workflow.add_node("save_neo4j", write_node) - workflow.add_edge(START, "save_neo4j") - workflow.add_edge("save_neo4j", END) - - graph = workflow.compile() - - yield graph - - -async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '', - end_user_id: str = '', scope: int = 6): +async def long_term_storage( + long_term_type: str, + langchain_messages: list, + memory_config_id: str, + end_user_id: str, + scope: int = 6 +): """ Handle long-term memory storage with different strategies @@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l Args: long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') langchain_messages: List of messages to store - memory_config: Memory configuration identifier + memory_config_id: Memory configuration identifier end_user_id: User group identifier scope: Scope parameter for chunk-based storage (default: 6) """ - from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ - aggregate_judgment - from app.core.memory.agent.utils.redis_tool import write_store + if langchain_messages is None: + langchain_messages = [] + write_store.save_session_write(end_user_id, langchain_messages) # 获取数据库会话 with get_db_context() as db_session: config_service = MemoryConfigService(db_session) memory_config = config_service.load_memory_config( - config_id=memory_config, # 改为整数 + config_id=memory_config_id, # 改为整数 service_name="MemoryAgentService" ) if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK: - '''Strategy 1: Dialogue window with 6 rounds of conversation''' + # Dialogue window with 6 rounds of conversation await window_dialogue(end_user_id, langchain_messages, memory_config, scope) if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME: - """Time-based strategy""" + # Time-based strategy await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE) if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE: - """Strategy 3: Aggregate judgment""" + # Aggregate judgment await aggregate_judgment(end_user_id, langchain_messages, memory_config) -async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id): +async def write_long_term( + storage_type: str, + end_user_id: str, + messages: list[dict], + user_rag_memory_id: str, + actual_config_id: str +): """ Write long-term memory with different storage types @@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u Args: storage_type: Type of storage (RAG or traditional) end_user_id: User group identifier - message_chat: User message content - aimessages: AI response messages + messages: message list user_rag_memory_id: RAG memory identifier actual_config_id: Actual configuration ID """ - from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save - from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages if storage_type == AgentMemory_Long_Term.STORAGE_RAG: - await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) + message_content = [] + for message in messages: + message_content.append(f'{message.get("role")}:{message.get("content")}') + messages_string = "\n".join(message_content) + await write_rag(end_user_id, messages_string, user_rag_memory_id) else: # AI reply writing (user messages and AI replies paired, written as complete dialogue at once) CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE - long_term_messages = await agent_chat_messages(message_chat, aimessages) - await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) - await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) - -# async def main(): -# """主函数 - 运行工作流""" -# langchain_messages = [ -# { -# "role": "user", -# "content": "今天周五去爬山" -# }, -# { -# "role": "assistant", -# "content": "好耶" -# } -# -# ] -# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID -# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" -# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) -# -# -# -# if __name__ == "__main__": -# import asyncio -# asyncio.run(main()) + await long_term_storage(long_term_type=CHUNK, + langchain_messages=messages, + memory_config_id=actual_config_id, + end_user_id=end_user_id, + scope=SCOPE) + await term_memory_save(end_user_id, CHUNK, scope=SCOPE) diff --git a/api/app/core/memory/agent/utils/redis_tool.py b/api/app/core/memory/agent/utils/redis_tool.py index c5729628..82b22c9e 100644 --- a/api/app/core/memory/agent/utils/redis_tool.py +++ b/api/app/core/memory/agent/utils/redis_tool.py @@ -3,8 +3,9 @@ import uuid from app.core.config import settings from typing import List, Dict, Any, Optional, Union +from app.core.logging_config import get_logger from app.core.memory.agent.utils.redis_base import ( - serialize_messages, + serialize_messages, deserialize_messages, fix_encoding, format_session_data, @@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import ( get_current_timestamp ) - +logger = get_logger(__name__) class RedisWriteStore: """Redis Write 类型存储类,用于管理 save_session_write 相关的数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -66,10 +67,10 @@ class RedisWriteStore: }) result = pipe.execute() - print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") + logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") return session_id except Exception as e: - print(f"[save_session_write] 保存会话失败: {e}") + logger.error(f"[save_session_write] 保存会话失败: {e}") raise e def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]: @@ -99,7 +100,7 @@ class RedisWriteStore: for key, data in zip(keys, all_data): if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == userid: # 从 key 中提取 session_id: session:write:{session_id} @@ -108,16 +109,16 @@ class RedisWriteStore: "sessionid": session_id, "messages": fix_encoding(data.get('messages', '')) }) - + if not results: return False - - print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") + + logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") return results except Exception as e: - print(f"[get_session_by_userid] 查询失败: {e}") + logger.error(f"[get_session_by_userid] 查询失败: {e}") return False - + def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]: """ 通过 end_user_id 获取所有 write 类型的会话数据 @@ -144,7 +145,7 @@ class RedisWriteStore: # 只查询 write 类型的 key keys = self.r.keys('session:write:*') if not keys: - print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") + logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") return False # 批量获取数据 @@ -158,12 +159,12 @@ class RedisWriteStore: for key, data in zip(keys, all_data): if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == end_user_id: # 从 key 中提取 session_id: session:write:{session_id} session_id = key.split(':')[-1] - + # 构建完整的会话信息 session_info = { "session_id": session_id, @@ -173,23 +174,21 @@ class RedisWriteStore: "starttime": data.get('starttime', '') } results.append(session_info) - + if not results: - print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") + logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") return False - + # 按时间排序(最新的在前) results.sort(key=lambda x: x.get('starttime', ''), reverse=True) - - print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") + + logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") return results except Exception as e: - print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}") - import traceback - traceback.print_exc() + logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True) return False - def find_user_recent_sessions(self, userid: str, + def find_user_recent_sessions(self, userid: str, minutes: int = 5) -> List[Dict[str, str]]: """ 根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据 @@ -203,11 +202,11 @@ class RedisWriteStore: """ import time start_time = time.time() - + # 只查询 write 类型的 key keys = self.r.keys('session:write:*') if not keys: - print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") return [] # 批量获取数据 @@ -221,7 +220,7 @@ class RedisWriteStore: for data in all_data: if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == userid and data.get('starttime'): # write 类型没有 aimessages,所以 Answer 为空 @@ -230,15 +229,14 @@ class RedisWriteStore: "Answer": "", "starttime": data.get('starttime', '') }) - + # 根据时间范围过滤 filtered_items = filter_by_time_range(matched_items, minutes) # 排序并移除时间字段 - result_items = sort_and_limit_results(filtered_items, limit=None) - print(result_items) + result_items = sort_and_limit_results(filtered_items) elapsed_time = time.time() - start_time - print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " + logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") return result_items @@ -258,7 +256,7 @@ class RedisWriteStore: class RedisCountStore: """Redis Count 类型存储类,用于管理访问次数统计相关的数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -278,7 +276,7 @@ class RedisCountStore: decode_responses=True, encoding='utf-8' ) - self.uudi = session_id + self.uuid = session_id def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str: """ @@ -295,26 +293,26 @@ class RedisCountStore: session_id = str(uuid.uuid4()) key = generate_session_key(session_id, key_type="count") index_key = f'session:count:index:{end_user_id}' # 索引键 - + pipe = self.r.pipeline() pipe.hset(key, mapping={ - "id": self.uudi, + "id": self.uuid, "end_user_id": end_user_id, "count": int(count), "messages": serialize_messages(messages), "starttime": get_current_timestamp() }) pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 - + # 创建索引:end_user_id -> session_id 映射 pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60) - + result = pipe.execute() - - print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") + + logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") return session_id - def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]: + def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool: """ 通过 end_user_id 查询访问次数统计 @@ -327,7 +325,7 @@ class RedisCountStore: try: # 使用索引键快速查找 index_key = f'session:count:index:{end_user_id}' - + # 检查索引键类型,避免 WRONGTYPE 错误 try: key_type = self.r.type(index_key) @@ -335,35 +333,40 @@ class RedisCountStore: self.r.delete(index_key) return False except Exception as type_error: - print(f"[get_sessions_count] 检查键类型失败: {type_error}") - + logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}") + session_id = self.r.get(index_key) - + if not session_id: return False - + # 直接获取数据 key = generate_session_key(session_id, key_type="count") data = self.r.hgetall(key) - + if not data: # 索引存在但数据不存在,清理索引 self.r.delete(index_key) return False - + count = data.get('count') messages_str = data.get('messages') - + if count is not None: - messages = deserialize_messages(messages_str) - return [int(count), messages] - + messages: list[dict] = deserialize_messages(messages_str) + return int(count), messages + return False except Exception as e: - print(f"[get_sessions_count] 查询失败: {e}") + logger.error(f"[get_sessions_count] 查询失败: {e}") return False - def update_sessions_count(self, end_user_id: str, new_count: int, - messages: Any) -> bool: + + def update_sessions_count( + self, + end_user_id: str, + new_count: int, + messages: Any + ) -> bool: """ 通过 end_user_id 修改访问次数统计(优化版:使用索引) @@ -378,39 +381,39 @@ class RedisCountStore: try: # 使用索引键快速查找 index_key = f'session:count:index:{end_user_id}' - + # 检查索引键类型,避免 WRONGTYPE 错误 try: key_type = self.r.type(index_key) if key_type != 'string' and key_type != 'none': # 索引键类型错误,删除并返回 False - print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") + logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") self.r.delete(index_key) - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") return False except Exception as type_error: - print(f"[update_sessions_count] 检查键类型失败: {type_error}") - + logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}") + session_id = self.r.get(index_key) - + if not session_id: - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") return False - + # 直接更新数据 key = generate_session_key(session_id, key_type="count") messages_str = serialize_messages(messages) - + pipe = self.r.pipeline() - pipe.hset(key, 'count', int(new_count)) + pipe.hset(key, 'count', str(new_count)) pipe.hset(key, 'messages', messages_str) result = pipe.execute() - - print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") + + logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") return True - + except Exception as e: - print(f"[update_sessions_count] 更新失败: {e}") + logger.debug(f"[update_sessions_count] 更新失败: {e}") return False def delete_all_count_sessions(self) -> int: @@ -428,7 +431,7 @@ class RedisCountStore: class RedisSessionStore: """Redis 会话存储类,用于管理会话数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -451,9 +454,9 @@ class RedisSessionStore: self.uudi = session_id # ==================== 写入操作 ==================== - - def save_session(self, userid: str, messages: str, aimessages: str, - apply_id: str, end_user_id: str) -> str: + + def save_session(self, userid: str, messages: str, aimessages: str, + apply_id: str, end_user_id: str) -> str: """ 写入一条会话数据,返回 session_id @@ -483,14 +486,14 @@ class RedisSessionStore: }) result = pipe.execute() - print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") + logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") return session_id except Exception as e: - print(f"[save_session] 保存会话失败: {e}") + logger.error(f"[save_session] 保存会话失败: {e}") raise e # ==================== 读取操作 ==================== - + def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """ 读取一条会话数据 @@ -520,8 +523,8 @@ class RedisSessionStore: sessions[sid] = self.get_session(sid) return sessions - def find_user_apply_group(self, sessionid: str, apply_id: str, - end_user_id: str) -> List[Dict[str, str]]: + def find_user_apply_group(self, sessionid: str, apply_id: str, + end_user_id: str) -> List[Dict[str, str]]: """ 根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条 @@ -535,10 +538,10 @@ class RedisSessionStore: """ import time start_time = time.time() - + keys = self.r.keys('session:*') if not keys: - print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") return [] # 批量获取数据 @@ -556,21 +559,21 @@ class RedisSessionStore: continue if (data.get('apply_id') == apply_id and - data.get('end_user_id') == end_user_id): + data.get('end_user_id') == end_user_id): # 支持模糊匹配或完全匹配 sessionid if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: matched_items.append(format_session_data(data, include_time=True)) - + # 排序、限制数量并移除时间字段 result_items = sort_and_limit_results(matched_items, limit=6) elapsed_time = time.time() - start_time - print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") + logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") return result_items # ==================== 更新操作 ==================== - + def update_session(self, session_id: str, field: str, value: Any) -> bool: """ 更新单个字段 @@ -591,7 +594,7 @@ class RedisSessionStore: return bool(results[0]) # ==================== 删除操作 ==================== - + def delete_session(self, session_id: str) -> int: """ 删除单条会话 @@ -632,7 +635,7 @@ class RedisSessionStore: keys = self.r.keys('session:*') if not keys: - print("[delete_duplicate_sessions] 没有会话数据") + logger.debug("[delete_duplicate_sessions] 没有会话数据") return 0 # 批量获取所有数据 @@ -678,7 +681,7 @@ class RedisSessionStore: deleted_count += len(batch) elapsed_time = time.time() - start_time - print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒") + logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒") return deleted_count diff --git a/api/app/core/memory/llm_tools/llm_client.py b/api/app/core/memory/llm_tools/llm_client.py index e26aba3e..49cd9434 100644 --- a/api/app/core/memory/llm_tools/llm_client.py +++ b/api/app/core/memory/llm_tools/llm_client.py @@ -56,7 +56,7 @@ class LLMClient(ABC): self.max_retries = self.config.max_retries self.timeout = self.config.timeout - logger.info( + logger.debug( f"初始化 LLM 客户端: provider={self.provider}, " f"model={self.model_name}, max_retries={self.max_retries}" ) diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index b4efe61d..97aa5bb5 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -17,6 +17,7 @@ class Write_UserInput(BaseModel): end_user_id: str config_id: Optional[str] = None + class AgentMemory_Long_Term(ABC): """长期记忆配置常量""" STORAGE_NEO4J = "neo4j" @@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC): STRATEGY_CHUNK = "chunk" STRATEGY_TIME = "time" DEFAULT_SCOPE = 6 - TIME_SCOPE=5 -class AgentMemoryDataset(ABC): - PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余'] - NAME='用户' + TIME_SCOPE = 5 + +class AgentMemoryDataset(ABC): + PRONOUN = ['我', '本人', '在下', '自己', '咱', '鄙人', '吴', '余'] + NAME = '用户' diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 90474428..17c2f98c 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from app.core.agent.langchain_agent import LangChainAgent from app.core.logging_config import get_business_logger +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.models import MultiAgentConfig, AgentConfig, ModelType from app.models import WorkflowConfig @@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService from app.services.draft_run_service import AgentRunService +from app.services.memory_agent_service import get_end_user_connected_config from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multimodal_service import MultimodalService from app.services.workflow_service import WorkflowService -from app.schemas import FileType logger = get_business_logger() @@ -43,18 +44,17 @@ class AppChatService: message: str, conversation_id: uuid.UUID, config: AgentConfig, - user_id: Optional[str] = None, + files: list[FileInput], + user_id: str, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, - workspace_id: Optional[str] = None, - files: Optional[List[FileInput]] = None + workspace_id: Optional[str] = None ) -> Dict[str, Any]: """聊天(非流式)""" start_time = time.time() - config_id = None # 应用 features 配置 features_config: dict = config.features or {} @@ -93,7 +93,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, + user_id) tools.extend(kb_tools) memory_flag = False if memory: @@ -168,11 +169,6 @@ class AppChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag, files=processed_files # 传递处理后的文件 ) @@ -229,6 +225,21 @@ class AppChatService: # 保存消息 if audio_url: assistant_meta["audio_url"] = audio_url + if memory_flag: + connected_config = get_end_user_connected_config(user_id, self.db) + memory_config_id: str = connected_config.get("memory_config_id") + messages = [ + {"role": "user", "content": message, "files": [file.model_dump() for file in files]}, + {"role": "assistant", "content": result["content"]} + ] + if memory_config_id: + await write_long_term( + storage_type, + user_id, + messages, + user_rag_memory_id, + memory_config_id + ) self.conversation_service.add_message( conversation_id=conversation_id, role="user", @@ -264,20 +275,19 @@ class AppChatService: message: str, conversation_id: uuid.UUID, config: AgentConfig, + files: list[FileInput], user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, - workspace_id: Optional[str] = None, - files: Optional[List[FileInput]] = None + workspace_id: Optional[str] = None ) -> AsyncGenerator[str, None]: """聊天(流式)""" try: start_time = time.time() - config_id = None message_id = uuid.uuid4() # 应用 features 配置 @@ -319,7 +329,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config( + config.knowledge_retrieval, user_id) tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False @@ -411,11 +422,6 @@ class AppChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag, files=processed_files ): if isinstance(chunk, int): @@ -459,7 +465,7 @@ class AppChatService: # 保存消息 human_meta = { - "files":[], + "files": [], "history_files": {} } assistant_meta = { @@ -484,6 +490,22 @@ class AppChatService: if stream_audio_url: assistant_meta["audio_url"] = stream_audio_url + + if memory_flag: + connected_config = get_end_user_connected_config(user_id, self.db) + memory_config_id: str = connected_config.get("memory_config_id") + messages = [ + {"role": "user", "content": message, "files": [file.model_dump() for file in files]}, + {"role": "assistant", "content": full_content} + ] + if memory_config_id: + await write_long_term( + storage_type, + user_id, + messages, + user_rag_memory_id, + memory_config_id + ) self.conversation_service.add_message( conversation_id=conversation_id, role="user", @@ -618,7 +640,6 @@ class AppChatService: # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) - # 3. 流式执行任务 async for event in orchestrator.execute_stream( message=message, diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index e188872f..aef54847 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context -from app.models import AgentConfig, ModelConfig, ModelType +from app.models import AgentConfig, ModelConfig from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput, Citation from app.schemas.model_schema import ModelInfo @@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_service import ModelApiKeyService from app.services.multimodal_service import MultimodalService from app.services.tool_service import ToolService -from app.schemas import FileType logger = get_business_logger() @@ -657,11 +656,6 @@ class AgentRunService: message=message, history=history, context=context, - end_user_id=user_id, - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_flag=memory_flag, files=processed_files # 传递处理后的文件 ) @@ -911,11 +905,6 @@ class AgentRunService: message=message, history=history, context=context, - end_user_id=user_id, - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_flag=memory_flag, files=processed_files ): if isinstance(chunk, int): diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 3ee238e2..5c838fc0 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -243,27 +243,6 @@ class MemoryPerceptualService: memory_config: MemoryConfig, file: FileInput ): - memories = self.repository.get_by_url(file.url) - if memories: - business_logger.info(f"Perceptual memory already exists: {file.url}") - if end_user_id not in [memory.end_user_id for memory in memories]: - business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") - memory_cache = memories[0] - memory = self.repository.create_perceptual_memory( - end_user_id=uuid.UUID(end_user_id), - perceptual_type=PerceptualType(memory_cache.perceptual_type), - file_path=memory_cache.file_path, - file_name=memory_cache.file_name, - file_ext=memory_cache.file_ext, - summary=memory_cache.summary, - meta_data=memory_cache.meta_data - ) - self.db.commit() - return memory - else: - for memory in memories: - if memory.end_user_id == uuid.UUID(end_user_id): - return memory llm, model_config = self._get_mutlimodal_client(file.type, memory_config) multimodel_service = MultimodalService(self.db, ModelInfo( model_name=model_config.model_name, diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index b98674ba..c9266667 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -69,7 +69,8 @@ class ModelConfigService: return items @staticmethod - def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig: + def get_model_by_name(db: Session, name: str, provider: str | None = None, + tenant_id: uuid.UUID | None = None) -> ModelConfig: """根据名称获取模型配置""" model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id) if not model: @@ -77,21 +78,22 @@ class ModelConfigService: return model @staticmethod - def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]: + def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ + ModelConfig]: """按名称模糊匹配获取模型配置列表""" return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit) @staticmethod async def validate_model_config( - db: Session, - *, - model_name: str, - provider: str, - api_key: str, - api_base: Optional[str] = None, - model_type: str = "llm", - test_message: str = "Hello", - is_omni: bool = False + db: Session, + *, + model_name: str, + provider: str, + api_key: str, + api_base: Optional[str] = None, + model_type: str = "llm", + test_message: str = "Hello", + is_omni: bool = False ) -> Dict[str, Any]: """验证模型配置是否有效 @@ -158,13 +160,13 @@ class ModelConfigService: # 统一使用 RedBearEmbeddings(自动支持火山引擎多模态) embedding = RedBearEmbeddings(model_config) test_texts = [test_message, "测试文本"] - + # 火山引擎使用 embed_batch,其他使用 embed_documents if provider.lower() == "volcano": vectors = await asyncio.to_thread(embedding.embed_batch, test_texts) else: vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) - + elapsed_time = time.time() - start_time return { @@ -200,11 +202,11 @@ class ModelConfigService: }, "error": None } - + elif model_type_lower == "image": # 图片生成模型验证 from app.core.models.generation import RedBearImageGenerator - + generator = RedBearImageGenerator(model_config) result = await generator.agenerate( prompt="a cute panda", @@ -212,7 +214,7 @@ class ModelConfigService: ) elapsed_time = time.time() - start_time logger.info(f"成功生成图片,结果: {result}") - + return { "valid": True, "message": "图片生成模型配置验证成功", @@ -224,21 +226,21 @@ class ModelConfigService: }, "error": None } - + elif model_type_lower == "video": # 视频生成模型验证 from app.core.models.generation import RedBearVideoGenerator - + generator = RedBearVideoGenerator(model_config) result = await generator.agenerate( prompt="a cute panda playing in bamboo forest", duration=5 ) elapsed_time = time.time() - start_time - + # 视频生成是异步任务,返回任务ID task_id = result.get("task_id") if isinstance(result, dict) else None - + return { "valid": True, "message": "视频生成模型配置验证成功", @@ -265,7 +267,6 @@ class ModelConfigService: # 提取详细的错误信息 error_message = str(e) error_type = type(e).__name__ - print("=========error_message:",error_message.lower()) # 特殊处理常见的错误类型 if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower(): # 区域/国家限制(适用于所有提供商) @@ -354,14 +355,16 @@ class ModelConfigService: return model @staticmethod - def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig: + def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, + tenant_id: uuid.UUID | None = None) -> ModelConfig: """更新模型配置""" existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) @@ -370,25 +373,27 @@ class ModelConfigService: return model @staticmethod - async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: + async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, + tenant_id: uuid.UUID) -> ModelConfig: """创建组合模型""" - if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - + # 验证所有 API Key 存在且类型匹配 for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) - + # 检查 API Key 关联的模型配置类型 for model_config in api_key.model_configs: # chat 和 llm 类型可以兼容 compatible_types = {ModelType.LLM, ModelType.CHAT} config_type = model_config.type request_type = model_data.type - - if not (config_type == request_type or + + if not (config_type == request_type or (config_type in compatible_types and request_type in compatible_types)): raise BusinessException( f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", @@ -399,7 +404,7 @@ class ModelConfigService: # f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型", # BizCode.INVALID_PARAMETER # ) - + # 创建组合模型 model_config_data = { "tenant_id": tenant_id, @@ -418,49 +423,51 @@ class ModelConfigService: model = ModelConfigRepository.create(db, model_config_data) db.flush() - + # 关联 API Keys for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if api_key: model.api_keys.append(api_key) - + db.commit() db.refresh(model) return model @staticmethod - async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: + async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, + tenant_id: uuid.UUID) -> ModelConfig: """更新组合模型""" existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - + if not existing_model.is_composite: raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER) - + # 验证所有 API Key 存在且类型匹配 for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) - + for model_config in api_key.model_configs: compatible_types = {ModelType.LLM, ModelType.CHAT} config_type = model_config.type request_type = existing_model.type - - if not (config_type == request_type or + + if not (config_type == request_type or (config_type in compatible_types and request_type in compatible_types)): raise BusinessException( f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", BizCode.INVALID_PARAMETER ) - + # 更新基本信息 existing_model.name = model_data.name # existing_model.type = model_data.type @@ -471,14 +478,14 @@ class ModelConfigService: existing_model.is_public = model_data.is_public if "load_balance_strategy" in model_data.model_fields_set: existing_model.load_balance_strategy = model_data.load_balance_strategy - + # 更新 API Keys 关联 existing_model.api_keys.clear() for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if api_key: existing_model.api_keys.append(api_key) - + db.commit() db.refresh(existing_model) return existing_model @@ -532,7 +539,7 @@ class ModelApiKeyService: """根据provider为多个ModelConfig创建API Key""" created_keys = [] failed_models = [] # 记录验证失败的模型 - + for model_config_id in data.model_config_ids: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: @@ -540,10 +547,10 @@ class ModelApiKeyService: data.is_omni = model_config.is_omni data.capability = model_config.capability - + # 从ModelBase获取model_name model_name = model_config.model_base.name if model_config.model_base else model_config.name - + # 检查是否存在API Key(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( ModelApiKey.model_configs @@ -553,7 +560,7 @@ class ModelApiKeyService: ModelApiKey.model_name == model_name, ModelConfig.tenant_id == model_config.tenant_id ).first() - + if existing_key: # 如果已存在,重新激活并更新 if existing_key.is_active: @@ -566,14 +573,14 @@ class ModelApiKeyService: existing_key.model_name = model_name existing_key.capability = data.capability existing_key.is_omni = data.is_omni - + # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: existing_key.model_configs.append(model_config) - + created_keys.append(existing_key) continue - + # 验证配置 validation_result = await ModelConfigService.validate_model_config( db=db, @@ -589,7 +596,7 @@ class ModelApiKeyService: # 记录验证失败的模型,但不抛出异常 failed_models.append(model_name) continue - + # 创建API Key api_key_data = ModelApiKeyCreate( model_config_ids=[model_config_id], @@ -606,12 +613,12 @@ class ModelApiKeyService: ) api_key_obj = ModelApiKeyRepository.create(db, api_key_data) created_keys.append(api_key_obj) - + if created_keys: db.commit() for key in created_keys: db.refresh(key) - + return created_keys, failed_models @staticmethod @@ -626,7 +633,7 @@ class ModelApiKeyService: api_key_data.is_omni = model_config.is_omni if api_key_data.capability is None: api_key_data.capability = model_config.capability - + # 检查API Key是否已存在(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( ModelApiKey.model_configs @@ -650,15 +657,15 @@ class ModelApiKeyService: existing_key.model_name = api_key_data.model_name existing_key.capability = api_key_data.capability existing_key.is_omni = api_key_data.is_omni - + # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: existing_key.model_configs.append(model_config) - + db.commit() db.refresh(existing_key) return existing_key - + # 验证配置 validation_result = await ModelConfigService.validate_model_config( db=db, @@ -691,7 +698,7 @@ class ModelApiKeyService: # 获取关联的模型配置以获取模型类型 if existing_api_key.model_configs: model_config = existing_api_key.model_configs[0] - + validation_result = await ModelConfigService.validate_model_config( db=db, model_name=api_key_data.model_name or existing_api_key.model_name, @@ -729,15 +736,15 @@ class ModelApiKeyService: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: return None - + api_keys = [key for key in model_config.api_keys if key.is_active] if not api_keys: return None - + # 如果是轮询策略,按使用次数最少,次数相同则选最早使用的 if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) - + # 否则返回第一个 return api_keys[0] @@ -760,20 +767,19 @@ class ModelApiKeyService: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - class ModelBaseService: """基础模型服务""" @staticmethod def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List: models = ModelBaseRepository.get_list(db, query) - + provider_groups = {} for m in models: model_dict = model_schema.ModelBase.model_validate(m).model_dump() if tenant_id: model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id) - + provider = m.provider if provider not in provider_groups: provider_groups[provider] = { @@ -781,7 +787,7 @@ class ModelBaseService: "models": [] } provider_groups[provider]["models"].append(model_dict) - + return list(provider_groups.values()) @staticmethod @@ -823,10 +829,10 @@ class ModelBaseService: model_base = ModelBaseRepository.get_by_id(db, model_base_id) if not model_base: raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) - + if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id): raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME) - + model_config_data = { "model_id": model_base_id, "tenant_id": tenant_id, diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index 0d659832..c74604a5 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -1,26 +1,24 @@ """基于分享链接的聊天服务""" -import uuid -import time import asyncio +import json +import time +import uuid from typing import Optional, Dict, Any, AsyncGenerator + +from deprecated import deprecated from sqlalchemy.orm import Session -from app.repositories.model_repository import ModelApiKeyRepository -from app.services.memory_konwledges_server import write_rag +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException, ResourceNotFoundException +from app.core.logging_config import get_business_logger +from app.models import MultiAgentConfig from app.models import ReleaseShare, AppRelease, Conversation +from app.repositories import knowledge_repository from app.services.conversation_service import ConversationService from app.services.draft_run_service import create_web_search_tool from app.services.model_service import ModelApiKeyService -from app.services.release_share_service import ReleaseShareService -from app.core.exceptions import BusinessException, ResourceNotFoundException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger from app.services.multi_agent_service import MultiAgentService -from app.models import MultiAgentConfig -from app.repositories import knowledge_repository -import json -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task +from app.services.release_share_service import ReleaseShareService logger = get_business_logger() @@ -118,6 +116,7 @@ class SharedChatService: return conversation + @deprecated("Use the chat method under app_chat_service instead.") async def chat( self, share_token: str, @@ -136,10 +135,7 @@ class SharedChatService: config_id = actual_config_id from app.core.agent.langchain_agent import LangChainAgent from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool - from app.services.model_parameter_merger import ModelParameterMerger from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole - from sqlalchemy import select - from app.models import ModelApiKey start_time = time.time() actual_config_id = None @@ -273,11 +269,6 @@ class SharedChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag ) # 保存消息 @@ -324,6 +315,7 @@ class SharedChatService: "elapsed_time": elapsed_time } + @deprecated("Use the chat method under app_chat_service instead.") async def chat_stream( self, share_token: str, @@ -341,8 +333,6 @@ class SharedChatService: from app.core.agent.langchain_agent import LangChainAgent from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole - from sqlalchemy import select - from app.models import ModelApiKey import json start_time = time.time() @@ -486,11 +476,6 @@ class SharedChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag ): if isinstance(chunk, int): total_tokens = chunk @@ -585,6 +570,7 @@ class SharedChatService: return conversations, total + @deprecated("Use the chat method under app_chat_service instead.") async def multi_agent_chat( self, share_token: str, @@ -680,6 +666,7 @@ class SharedChatService: "elapsed_time": elapsed_time } + @deprecated("Use the chat method under app_chat_service instead.") async def multi_agent_chat_stream( self, share_token: str,