feat(agent, memory): add agent-perceived memory writing

This commit is contained in:
Eternity
2026-03-30 11:47:58 +08:00
parent a5bce221bd
commit 7acb7045f0
12 changed files with 304 additions and 530 deletions

View File

@@ -410,30 +410,6 @@ async def chat(
agent_config = agent_config_4_app_release(release) agent_config = agent_config_4_app_release(release)
if payload.stream: if payload.stream:
# async def event_generator():
# async for event in service.chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
async def event_generator(): async def event_generator():
async for event in app_chat_service.agnet_chat_stream( async for event in app_chat_service.agnet_chat_stream(
message=payload.message, message=payload.message,
@@ -459,20 +435,6 @@ async def chat(
"X-Accel-Buffering": "no" "X-Accel-Buffering": "no"
} }
) )
# 非流式返回
# result = await service.chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
result = await app_chat_service.agnet_chat( result = await app_chat_service.agnet_chat(
message=payload.message, message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID conversation_id=conversation.id, # 使用已创建的会话 ID
@@ -531,48 +493,6 @@ async def chat(
) )
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
# 多 Agent 流式返回
# if payload.stream:
# async def event_generator():
# async for event in service.multi_agent_chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
# # 多 Agent 非流式返回
# result = await service.multi_agent_chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW: elif app_type == AppType.WORKFLOW:
config = workflow_config_4_app_release(release) config = workflow_config_4_app_release(release)
if not config.id: if not config.id:

View File

@@ -11,18 +11,14 @@ LangChain Agent 封装
import time import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType, ModelProvider
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
logger = get_business_logger() logger = get_business_logger()
@@ -226,10 +222,9 @@ class LangChainAgent:
Returns: Returns:
List[BaseMessage]: 消息列表 List[BaseMessage]: 消息列表
""" """
messages = [] messages:list = [SystemMessage(content=self.system_prompt)]
# 添加系统提示词 # 添加系统提示词
messages.append(SystemMessage(content=self.system_prompt))
# 添加历史消息 # 添加历史消息
if history: if history:
@@ -293,12 +288,7 @@ class LangChainAgent:
message: str, message: str,
history: Optional[List[Dict[str, str]]] = None, history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None, context: Optional[str] = None,
end_user_id: Optional[str] = None, files: Optional[List[Dict[str, Any]]] = None
config_id: Optional[str] = None, # 添加这个参数
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""执行对话 """执行对话
@@ -306,32 +296,12 @@ class LangChainAgent:
message: 用户消息 message: 用户消息
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}] history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
context: 上下文信息(如知识库检索结果) context: 上下文信息(如知识库检索结果)
files: 多模态文件
Returns: Returns:
Dict: 包含 content 和元数据的字典 Dict: 包含 content 和元数据的字典
""" """
message_chat = message
start_time = time.time() start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try: try:
# 准备消息列表(支持多模态) # 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files) messages = self._prepare_messages(message, history, context, files)
@@ -419,9 +389,6 @@ class LangChainAgent:
logger.info(f"最终提取的内容长度: {len(content)}") logger.info(f"最终提取的内容长度: {len(content)}")
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
response = { response = {
"content": content, "content": content,
"model": self.model_name, "model": self.model_name,
@@ -452,12 +419,7 @@ class LangChainAgent:
message: str, message: str,
history: Optional[List[Dict[str, str]]] = None, history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None, context: Optional[str] = None,
end_user_id: Optional[str] = None, files: Optional[List[Dict[str, Any]]] = None
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""执行流式对话 """执行流式对话
@@ -465,6 +427,7 @@ class LangChainAgent:
message: 用户消息 message: 用户消息
history: 历史消息列表 history: 历史消息列表
context: 上下文信息 context: 上下文信息
files: 多模态文件
Yields: Yields:
str: 消息内容块 str: 消息内容块
@@ -475,23 +438,6 @@ class LangChainAgent:
logger.info(f" Has tools: {bool(self.tools)}") logger.info(f" Has tools: {bool(self.tools)}")
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}") logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80) logger.info("=" * 80)
message_chat = message
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try: try:
# 准备消息列表(支持多模态) # 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files) messages = self._prepare_messages(message, history, context, files)
@@ -501,17 +447,18 @@ class LangChainAgent:
) )
chunk_count = 0 chunk_count = 0
yielded_content = False
# 统一使用 agent 的 astream_events 实现流式输出 # 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出") logger.debug("使用 Agent astream_events 实现流式输出")
full_content = '' full_content = ''
try: try:
last_event = {}
async for event in self.agent.astream_events( async for event in self.agent.astream_events(
{"messages": messages}, {"messages": messages},
version="v2", version="v2",
config={"recursion_limit": self.max_iterations} config={"recursion_limit": self.max_iterations}
): ):
last_event = event
chunk_count += 1 chunk_count += 1
kind = event.get("event") kind = event.get("event")
@@ -525,7 +472,6 @@ class LangChainAgent:
if isinstance(chunk_content, str) and chunk_content: if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content full_content += chunk_content
yield chunk_content yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list): elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分 # 多模态响应:提取文本部分
for item in chunk_content: for item in chunk_content:
@@ -536,18 +482,15 @@ class LangChainAgent:
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."} # OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text": elif item.get("type") == "text":
text = item.get("text", "") text = item.get("text", "")
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
elif isinstance(item, str): elif isinstance(item, str):
full_content += item full_content += item
yield item yield item
yielded_content = True
elif kind == "on_llm_stream": elif kind == "on_llm_stream":
# 另一种 LLM 流式事件 # 另一种 LLM 流式事件
@@ -558,7 +501,6 @@ class LangChainAgent:
if isinstance(chunk_content, str) and chunk_content: if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content full_content += chunk_content
yield chunk_content yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list): elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分 # 多模态响应:提取文本部分
for item in chunk_content: for item in chunk_content:
@@ -569,22 +511,18 @@ class LangChainAgent:
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."} # OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text": elif item.get("type") == "text":
text = item.get("text", "") text = item.get("text", "")
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
elif isinstance(item, str): elif isinstance(item, str):
full_content += item full_content += item
yield item yield item
yielded_content = True
elif isinstance(chunk, str): elif isinstance(chunk, str):
full_content += chunk full_content += chunk
yield chunk yield chunk
yielded_content = True
# 记录工具调用(可选) # 记录工具调用(可选)
elif kind == "on_tool_start": elif kind == "on_tool_start":
@@ -594,7 +532,7 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗 # 统计token消耗
output_messages = event.get("data", {}).get("output", {}).get("messages", []) output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages): for msg in reversed(output_messages):
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
@@ -604,9 +542,7 @@ class LangChainAgent:
) if response_meta else 0 ) if response_meta else 0
yield total_tokens yield total_tokens
break break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
except Exception as e: except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise raise

View File

@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
"""
Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
ai_message: AI's response message content
user_rag_memory_id: RAG memory identifier for storage location
"""
# RAG mode: combine messages into string format (maintain original logic)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write( async def write(
storage_type, storage_type,
end_user_id, end_user_id,
@@ -118,7 +98,7 @@ async def write(
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope): async def term_memory_save(end_user_id, strategy_type, scope):
""" """
Save long-term memory data to database Save long-term memory data to database
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
to long-term memory storage. to long-term memory storage.
Args: Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
end_user_id: User identifier for memory association end_user_id: User identifier for memory association
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
scope: Scope/window size for memory processing scope: Scope/window size for memory processing
""" """
with get_db_context() as db_session: with get_db_context() as db_session:
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id) result = write_store.get_session_by_userid(end_user_id)
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: if not result:
logger.warning(f"No write data found for user {end_user_id}")
return
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
data = await format_parsing(result, "dict") data = await format_parsing(result, "dict")
chunk_data = data[:scope] chunk_data = data[:scope]
if len(chunk_data) == scope: if len(chunk_data) == scope:
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
logger.info(f'写入短长期:') logger.info(f'写入短长期:')
"""Window-based dialogue processing"""
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope): async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
""" """
Process dialogue based on window size and write to Neo4j Process dialogue based on window size and write to Neo4j
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
langchain_messages: Original message data list langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage scope: Window size determining when to trigger long-term storage
""" """
scope = scope is_end_user_has_history = count_store.get_sessions_count(end_user_id)
is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_has_history:
if is_end_user_id is not False: end_user_visit_count, redis_messages = is_end_user_has_history
is_end_user_id = count_store.get_sessions_count(end_user_id)[0] else:
redis_messages = count_store.get_sessions_count(end_user_id)[1] count_store.save_sessions_count(end_user_id, 1, langchain_messages)
if is_end_user_id and int(is_end_user_id) != int(scope): return
is_end_user_id += 1 end_user_visit_count += 1
langchain_messages += redis_messages if end_user_visit_count < scope:
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) redis_messages.extend(langchain_messages)
elif int(is_end_user_id) == int(scope): count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
else:
logger.info('写入长期记忆NEO4J') logger.info('写入长期记忆NEO4J')
formatted_messages = redis_messages redis_messages.extend(langchain_messages)
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly) # Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'): if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id config_id = memory_config.config_id
else: else:
config_id = memory_config config_id = memory_config
await write( write_message_task.delay(
AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, # end_user_id: User ID
end_user_id, redis_messages, # message: JSON string format message list
"", config_id, # config_id: Configuration ID string
"", AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
None, "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
end_user_id,
config_id,
formatted_messages
) )
count_store.update_sessions_count(end_user_id, 1, langchain_messages) count_store.update_sessions_count(end_user_id, 0, [])
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""Time-based memory processing"""
async def memory_long_term_storage(end_user_id, memory_config, time): async def memory_long_term_storage(end_user_id, memory_config, time):
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
return result_dict return result_dict
except Exception as e: except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}") logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
import traceback
traceback.print_exc()
return { return {
"is_same_event": False, "is_same_event": False,

View File

@@ -1,49 +1,25 @@
import asyncio
import json
import sys
import warnings import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
from app.db import get_db_context
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import write_rag
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
async def long_term_storage(
@asynccontextmanager long_term_type: str,
async def make_write_graph(): langchain_messages: list,
""" memory_config_id: str,
Create a write graph workflow for memory operations. end_user_id: str,
scope: int = 6
Args: ):
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()
yield graph
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
""" """
Handle long-term memory storage with different strategies Handle long-term memory storage with different strategies
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
Args: Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store langchain_messages: List of messages to store
memory_config: Memory configuration identifier memory_config_id: Memory configuration identifier
end_user_id: User group identifier end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6) scope: Scope parameter for chunk-based storage (default: 6)
""" """
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ if langchain_messages is None:
aggregate_judgment langchain_messages = []
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, langchain_messages) write_store.save_session_write(end_user_id, langchain_messages)
# 获取数据库会话 # 获取数据库会话
with get_db_context() as db_session: with get_db_context() as db_session:
config_service = MemoryConfigService(db_session) config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数 config_id=memory_config_id, # 改为整数
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK: if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''Strategy 1: Dialogue window with 6 rounds of conversation''' # Dialogue window with 6 rounds of conversation
await window_dialogue(end_user_id, langchain_messages, memory_config, scope) await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME: if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""Time-based strategy""" # Time-based strategy
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE) await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE: if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""Strategy 3: Aggregate judgment""" # Aggregate judgment
await aggregate_judgment(end_user_id, langchain_messages, memory_config) await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id): async def write_long_term(
storage_type: str,
end_user_id: str,
messages: list[dict],
user_rag_memory_id: str,
actual_config_id: str
):
""" """
Write long-term memory with different storage types Write long-term memory with different storage types
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
Args: Args:
storage_type: Type of storage (RAG or traditional) storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier end_user_id: User group identifier
message_chat: User message content messages: message list
aimessages: AI response messages
user_rag_memory_id: RAG memory identifier user_rag_memory_id: RAG memory identifier
actual_config_id: Actual configuration ID actual_config_id: Actual configuration ID
""" """
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG: if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) message_content = []
for message in messages:
message_content.append(f'{message.get("role")}:{message.get("content")}')
messages_string = "\n".join(message_content)
await write_rag(end_user_id, messages_string, user_rag_memory_id)
else: else:
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once) # AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages) await long_term_storage(long_term_type=CHUNK,
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, langchain_messages=messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) memory_config_id=actual_config_id,
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) end_user_id=end_user_id,
scope=SCOPE)
# async def main(): await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "好耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -3,8 +3,9 @@ import uuid
from app.core.config import settings from app.core.config import settings
from typing import List, Dict, Any, Optional, Union from typing import List, Dict, Any, Optional, Union
from app.core.logging_config import get_logger
from app.core.memory.agent.utils.redis_base import ( from app.core.memory.agent.utils.redis_base import (
serialize_messages, serialize_messages,
deserialize_messages, deserialize_messages,
fix_encoding, fix_encoding,
format_session_data, format_session_data,
@@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import (
get_current_timestamp get_current_timestamp
) )
logger = get_logger(__name__)
class RedisWriteStore: class RedisWriteStore:
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据""" """Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
""" """
初始化 Redis 连接 初始化 Redis 连接
@@ -66,10 +67,10 @@ class RedisWriteStore:
}) })
result = pipe.execute() result = pipe.execute()
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
return session_id return session_id
except Exception as e: except Exception as e:
print(f"[save_session_write] 保存会话失败: {e}") logger.error(f"[save_session_write] 保存会话失败: {e}")
raise e raise e
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]: def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
@@ -99,7 +100,7 @@ class RedisWriteStore:
for key, data in zip(keys, all_data): for key, data in zip(keys, all_data):
if not data: if not data:
continue continue
# 从 write 类型读取,匹配 sessionid 字段 # 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid: if data.get('sessionid') == userid:
# 从 key 中提取 session_id: session:write:{session_id} # 从 key 中提取 session_id: session:write:{session_id}
@@ -108,16 +109,16 @@ class RedisWriteStore:
"sessionid": session_id, "sessionid": session_id,
"messages": fix_encoding(data.get('messages', '')) "messages": fix_encoding(data.get('messages', ''))
}) })
if not results: if not results:
return False return False
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
return results return results
except Exception as e: except Exception as e:
print(f"[get_session_by_userid] 查询失败: {e}") logger.error(f"[get_session_by_userid] 查询失败: {e}")
return False return False
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]: def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
""" """
通过 end_user_id 获取所有 write 类型的会话数据 通过 end_user_id 获取所有 write 类型的会话数据
@@ -144,7 +145,7 @@ class RedisWriteStore:
# 只查询 write 类型的 key # 只查询 write 类型的 key
keys = self.r.keys('session:write:*') keys = self.r.keys('session:write:*')
if not keys: if not keys:
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
return False return False
# 批量获取数据 # 批量获取数据
@@ -158,12 +159,12 @@ class RedisWriteStore:
for key, data in zip(keys, all_data): for key, data in zip(keys, all_data):
if not data: if not data:
continue continue
# 从 write 类型读取,匹配 sessionid 字段 # 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == end_user_id: if data.get('sessionid') == end_user_id:
# 从 key 中提取 session_id: session:write:{session_id} # 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1] session_id = key.split(':')[-1]
# 构建完整的会话信息 # 构建完整的会话信息
session_info = { session_info = {
"session_id": session_id, "session_id": session_id,
@@ -173,23 +174,21 @@ class RedisWriteStore:
"starttime": data.get('starttime', '') "starttime": data.get('starttime', '')
} }
results.append(session_info) results.append(session_info)
if not results: if not results:
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
return False return False
# 按时间排序(最新的在前) # 按时间排序(最新的在前)
results.sort(key=lambda x: x.get('starttime', ''), reverse=True) results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
return results return results
except Exception as e: except Exception as e:
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}") logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
import traceback
traceback.print_exc()
return False return False
def find_user_recent_sessions(self, userid: str, def find_user_recent_sessions(self, userid: str,
minutes: int = 5) -> List[Dict[str, str]]: minutes: int = 5) -> List[Dict[str, str]]:
""" """
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据 根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
@@ -203,11 +202,11 @@ class RedisWriteStore:
""" """
import time import time
start_time = time.time() start_time = time.time()
# 只查询 write 类型的 key # 只查询 write 类型的 key
keys = self.r.keys('session:write:*') keys = self.r.keys('session:write:*')
if not keys: if not keys:
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return [] return []
# 批量获取数据 # 批量获取数据
@@ -221,7 +220,7 @@ class RedisWriteStore:
for data in all_data: for data in all_data:
if not data: if not data:
continue continue
# 从 write 类型读取,匹配 sessionid 字段 # 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid and data.get('starttime'): if data.get('sessionid') == userid and data.get('starttime'):
# write 类型没有 aimessages所以 Answer 为空 # write 类型没有 aimessages所以 Answer 为空
@@ -230,15 +229,14 @@ class RedisWriteStore:
"Answer": "", "Answer": "",
"starttime": data.get('starttime', '') "starttime": data.get('starttime', '')
}) })
# 根据时间范围过滤 # 根据时间范围过滤
filtered_items = filter_by_time_range(matched_items, minutes) filtered_items = filter_by_time_range(matched_items, minutes)
# 排序并移除时间字段 # 排序并移除时间字段
result_items = sort_and_limit_results(filtered_items, limit=None) result_items = sort_and_limit_results(filtered_items)
print(result_items)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items return result_items
@@ -258,7 +256,7 @@ class RedisWriteStore:
class RedisCountStore: class RedisCountStore:
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据""" """Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
""" """
初始化 Redis 连接 初始化 Redis 连接
@@ -278,7 +276,7 @@ class RedisCountStore:
decode_responses=True, decode_responses=True,
encoding='utf-8' encoding='utf-8'
) )
self.uudi = session_id self.uuid = session_id
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str: def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
""" """
@@ -295,26 +293,26 @@ class RedisCountStore:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count") key = generate_session_key(session_id, key_type="count")
index_key = f'session:count:index:{end_user_id}' # 索引键 index_key = f'session:count:index:{end_user_id}' # 索引键
pipe = self.r.pipeline() pipe = self.r.pipeline()
pipe.hset(key, mapping={ pipe.hset(key, mapping={
"id": self.uudi, "id": self.uuid,
"end_user_id": end_user_id, "end_user_id": end_user_id,
"count": int(count), "count": int(count),
"messages": serialize_messages(messages), "messages": serialize_messages(messages),
"starttime": get_current_timestamp() "starttime": get_current_timestamp()
}) })
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
# 创建索引end_user_id -> session_id 映射 # 创建索引end_user_id -> session_id 映射
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60) pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
result = pipe.execute() result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
return session_id return session_id
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]: def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
""" """
通过 end_user_id 查询访问次数统计 通过 end_user_id 查询访问次数统计
@@ -327,7 +325,7 @@ class RedisCountStore:
try: try:
# 使用索引键快速查找 # 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}' index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误 # 检查索引键类型,避免 WRONGTYPE 错误
try: try:
key_type = self.r.type(index_key) key_type = self.r.type(index_key)
@@ -335,35 +333,40 @@ class RedisCountStore:
self.r.delete(index_key) self.r.delete(index_key)
return False return False
except Exception as type_error: except Exception as type_error:
print(f"[get_sessions_count] 检查键类型失败: {type_error}") logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key) session_id = self.r.get(index_key)
if not session_id: if not session_id:
return False return False
# 直接获取数据 # 直接获取数据
key = generate_session_key(session_id, key_type="count") key = generate_session_key(session_id, key_type="count")
data = self.r.hgetall(key) data = self.r.hgetall(key)
if not data: if not data:
# 索引存在但数据不存在,清理索引 # 索引存在但数据不存在,清理索引
self.r.delete(index_key) self.r.delete(index_key)
return False return False
count = data.get('count') count = data.get('count')
messages_str = data.get('messages') messages_str = data.get('messages')
if count is not None: if count is not None:
messages = deserialize_messages(messages_str) messages: list[dict] = deserialize_messages(messages_str)
return [int(count), messages] return int(count), messages
return False return False
except Exception as e: except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}") logger.error(f"[get_sessions_count] 查询失败: {e}")
return False return False
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool: def update_sessions_count(
self,
end_user_id: str,
new_count: int,
messages: Any
) -> bool:
""" """
通过 end_user_id 修改访问次数统计(优化版:使用索引) 通过 end_user_id 修改访问次数统计(优化版:使用索引)
@@ -378,39 +381,39 @@ class RedisCountStore:
try: try:
# 使用索引键快速查找 # 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}' index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误 # 检查索引键类型,避免 WRONGTYPE 错误
try: try:
key_type = self.r.type(index_key) key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none': if key_type != 'string' and key_type != 'none':
# 索引键类型错误,删除并返回 False # 索引键类型错误,删除并返回 False
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key) self.r.delete(index_key)
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False return False
except Exception as type_error: except Exception as type_error:
print(f"[update_sessions_count] 检查键类型失败: {type_error}") logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key) session_id = self.r.get(index_key)
if not session_id: if not session_id:
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False return False
# 直接更新数据 # 直接更新数据
key = generate_session_key(session_id, key_type="count") key = generate_session_key(session_id, key_type="count")
messages_str = serialize_messages(messages) messages_str = serialize_messages(messages)
pipe = self.r.pipeline() pipe = self.r.pipeline()
pipe.hset(key, 'count', int(new_count)) pipe.hset(key, 'count', str(new_count))
pipe.hset(key, 'messages', messages_str) pipe.hset(key, 'messages', messages_str)
result = pipe.execute() result = pipe.execute()
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True return True
except Exception as e: except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}") logger.debug(f"[update_sessions_count] 更新失败: {e}")
return False return False
def delete_all_count_sessions(self) -> int: def delete_all_count_sessions(self) -> int:
@@ -428,7 +431,7 @@ class RedisCountStore:
class RedisSessionStore: class RedisSessionStore:
"""Redis 会话存储类,用于管理会话数据""" """Redis 会话存储类,用于管理会话数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
""" """
初始化 Redis 连接 初始化 Redis 连接
@@ -451,9 +454,9 @@ class RedisSessionStore:
self.uudi = session_id self.uudi = session_id
# ==================== 写入操作 ==================== # ==================== 写入操作 ====================
def save_session(self, userid: str, messages: str, aimessages: str, def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str: apply_id: str, end_user_id: str) -> str:
""" """
写入一条会话数据,返回 session_id 写入一条会话数据,返回 session_id
@@ -483,14 +486,14 @@ class RedisSessionStore:
}) })
result = pipe.execute() result = pipe.execute()
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
return session_id return session_id
except Exception as e: except Exception as e:
print(f"[save_session] 保存会话失败: {e}") logger.error(f"[save_session] 保存会话失败: {e}")
raise e raise e
# ==================== 读取操作 ==================== # ==================== 读取操作 ====================
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
""" """
读取一条会话数据 读取一条会话数据
@@ -520,8 +523,8 @@ class RedisSessionStore:
sessions[sid] = self.get_session(sid) sessions[sid] = self.get_session(sid)
return sessions return sessions
def find_user_apply_group(self, sessionid: str, apply_id: str, def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]: end_user_id: str) -> List[Dict[str, str]]:
""" """
根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条 根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条
@@ -535,10 +538,10 @@ class RedisSessionStore:
""" """
import time import time
start_time = time.time() start_time = time.time()
keys = self.r.keys('session:*') keys = self.r.keys('session:*')
if not keys: if not keys:
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return [] return []
# 批量获取数据 # 批量获取数据
@@ -556,21 +559,21 @@ class RedisSessionStore:
continue continue
if (data.get('apply_id') == apply_id and if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id): data.get('end_user_id') == end_user_id):
# 支持模糊匹配或完全匹配 sessionid # 支持模糊匹配或完全匹配 sessionid
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append(format_session_data(data, include_time=True)) matched_items.append(format_session_data(data, include_time=True))
# 排序、限制数量并移除时间字段 # 排序、限制数量并移除时间字段
result_items = sort_and_limit_results(matched_items, limit=6) result_items = sort_and_limit_results(matched_items, limit=6)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items return result_items
# ==================== 更新操作 ==================== # ==================== 更新操作 ====================
def update_session(self, session_id: str, field: str, value: Any) -> bool: def update_session(self, session_id: str, field: str, value: Any) -> bool:
""" """
更新单个字段 更新单个字段
@@ -591,7 +594,7 @@ class RedisSessionStore:
return bool(results[0]) return bool(results[0])
# ==================== 删除操作 ==================== # ==================== 删除操作 ====================
def delete_session(self, session_id: str) -> int: def delete_session(self, session_id: str) -> int:
""" """
删除单条会话 删除单条会话
@@ -632,7 +635,7 @@ class RedisSessionStore:
keys = self.r.keys('session:*') keys = self.r.keys('session:*')
if not keys: if not keys:
print("[delete_duplicate_sessions] 没有会话数据") logger.debug("[delete_duplicate_sessions] 没有会话数据")
return 0 return 0
# 批量获取所有数据 # 批量获取所有数据
@@ -678,7 +681,7 @@ class RedisSessionStore:
deleted_count += len(batch) deleted_count += len(batch)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}") logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
return deleted_count return deleted_count

View File

@@ -56,7 +56,7 @@ class LLMClient(ABC):
self.max_retries = self.config.max_retries self.max_retries = self.config.max_retries
self.timeout = self.config.timeout self.timeout = self.config.timeout
logger.info( logger.debug(
f"初始化 LLM 客户端: provider={self.provider}, " f"初始化 LLM 客户端: provider={self.provider}, "
f"model={self.model_name}, max_retries={self.max_retries}" f"model={self.model_name}, max_retries={self.max_retries}"
) )

View File

@@ -17,6 +17,7 @@ class Write_UserInput(BaseModel):
end_user_id: str end_user_id: str
config_id: Optional[str] = None config_id: Optional[str] = None
class AgentMemory_Long_Term(ABC): class AgentMemory_Long_Term(ABC):
"""长期记忆配置常量""" """长期记忆配置常量"""
STORAGE_NEO4J = "neo4j" STORAGE_NEO4J = "neo4j"
@@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC):
STRATEGY_CHUNK = "chunk" STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time" STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6 DEFAULT_SCOPE = 6
TIME_SCOPE=5 TIME_SCOPE = 5
class AgentMemoryDataset(ABC):
PRONOUN=['','本人','在下','自己','','鄙人','','']
NAME='用户'
class AgentMemoryDataset(ABC):
PRONOUN = ['', '本人', '在下', '自己', '', '鄙人', '', '']
NAME = '用户'

View File

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig, ModelType from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig from app.models import WorkflowConfig
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.draft_run_service import AgentRunService from app.services.draft_run_service import AgentRunService
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.services.workflow_service import WorkflowService from app.services.workflow_service import WorkflowService
from app.schemas import FileType
logger = get_business_logger() logger = get_business_logger()
@@ -43,18 +44,17 @@ class AppChatService:
message: str, message: str,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
config: AgentConfig, config: AgentConfig,
user_id: Optional[str] = None, files: list[FileInput],
user_id: str,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
web_search: bool = False, web_search: bool = False,
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None
files: Optional[List[FileInput]] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""聊天(非流式)""" """聊天(非流式)"""
start_time = time.time() start_time = time.time()
config_id = None
# 应用 features 配置 # 应用 features 配置
features_config: dict = config.features or {} features_config: dict = config.features or {}
@@ -93,7 +93,8 @@ class AppChatService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
user_id)
tools.extend(kb_tools) tools.extend(kb_tools)
memory_flag = False memory_flag = False
if memory: if memory:
@@ -168,11 +169,6 @@ class AppChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag,
files=processed_files # 传递处理后的文件 files=processed_files # 传递处理后的文件
) )
@@ -229,6 +225,21 @@ class AppChatService:
# 保存消息 # 保存消息
if audio_url: if audio_url:
assistant_meta["audio_url"] = audio_url assistant_meta["audio_url"] = audio_url
if memory_flag:
connected_config = get_end_user_connected_config(user_id, self.db)
memory_config_id: str = connected_config.get("memory_config_id")
messages = [
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
{"role": "assistant", "content": result["content"]}
]
if memory_config_id:
await write_long_term(
storage_type,
user_id,
messages,
user_rag_memory_id,
memory_config_id
)
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation_id, conversation_id=conversation_id,
role="user", role="user",
@@ -264,20 +275,19 @@ class AppChatService:
message: str, message: str,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
config: AgentConfig, config: AgentConfig,
files: list[FileInput],
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
web_search: bool = False, web_search: bool = False,
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None
files: Optional[List[FileInput]] = None
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""聊天(流式)""" """聊天(流式)"""
try: try:
start_time = time.time() start_time = time.time()
config_id = None
message_id = uuid.uuid4() message_id = uuid.uuid4()
# 应用 features 配置 # 应用 features 配置
@@ -319,7 +329,8 @@ class AppChatService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
config.knowledge_retrieval, user_id)
tools.extend(kb_tools) tools.extend(kb_tools)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
@@ -411,11 +422,6 @@ class AppChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag,
files=processed_files files=processed_files
): ):
if isinstance(chunk, int): if isinstance(chunk, int):
@@ -459,7 +465,7 @@ class AppChatService:
# 保存消息 # 保存消息
human_meta = { human_meta = {
"files":[], "files": [],
"history_files": {} "history_files": {}
} }
assistant_meta = { assistant_meta = {
@@ -484,6 +490,22 @@ class AppChatService:
if stream_audio_url: if stream_audio_url:
assistant_meta["audio_url"] = stream_audio_url assistant_meta["audio_url"] = stream_audio_url
if memory_flag:
connected_config = get_end_user_connected_config(user_id, self.db)
memory_config_id: str = connected_config.get("memory_config_id")
messages = [
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
{"role": "assistant", "content": full_content}
]
if memory_config_id:
await write_long_term(
storage_type,
user_id,
messages,
user_rag_memory_id,
memory_config_id
)
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation_id, conversation_id=conversation_id,
role="user", role="user",
@@ -618,7 +640,6 @@ class AppChatService:
# 2. 创建编排器 # 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config) orchestrator = MultiAgentOrchestrator(self.db, config)
# 3. 流式执行任务 # 3. 流式执行任务
async for event in orchestrator.execute_stream( async for event in orchestrator.execute_stream(
message=message, message=message,

View File

@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context from app.db import get_db_context
from app.models import AgentConfig, ModelConfig, ModelType from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput, Citation from app.schemas.app_schema import FileInput, Citation
from app.schemas.model_schema import ModelInfo from app.schemas.model_schema import ModelInfo
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService from app.services.tool_service import ToolService
from app.schemas import FileType
logger = get_business_logger() logger = get_business_logger()
@@ -657,11 +656,6 @@ class AgentRunService:
message=message, message=message,
history=history, history=history,
context=context, context=context,
end_user_id=user_id,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag,
files=processed_files # 传递处理后的文件 files=processed_files # 传递处理后的文件
) )
@@ -911,11 +905,6 @@ class AgentRunService:
message=message, message=message,
history=history, history=history,
context=context, context=context,
end_user_id=user_id,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag,
files=processed_files files=processed_files
): ):
if isinstance(chunk, int): if isinstance(chunk, int):

View File

@@ -243,27 +243,6 @@ class MemoryPerceptualService:
memory_config: MemoryConfig, memory_config: MemoryConfig,
file: FileInput file: FileInput
): ):
memories = self.repository.get_by_url(file.url)
if memories:
business_logger.info(f"Perceptual memory already exists: {file.url}")
if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
file_name=memory_cache.file_name,
file_ext=memory_cache.file_ext,
summary=memory_cache.summary,
meta_data=memory_cache.meta_data
)
self.db.commit()
return memory
else:
for memory in memories:
if memory.end_user_id == uuid.UUID(end_user_id):
return memory
llm, model_config = self._get_mutlimodal_client(file.type, memory_config) llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
multimodel_service = MultimodalService(self.db, ModelInfo( multimodel_service = MultimodalService(self.db, ModelInfo(
model_name=model_config.model_name, model_name=model_config.model_name,

View File

@@ -69,7 +69,8 @@ class ModelConfigService:
return items return items
@staticmethod @staticmethod
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig: def get_model_by_name(db: Session, name: str, provider: str | None = None,
tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""根据名称获取模型配置""" """根据名称获取模型配置"""
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id) model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
if not model: if not model:
@@ -77,21 +78,22 @@ class ModelConfigService:
return model return model
@staticmethod @staticmethod
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]: def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
ModelConfig]:
"""按名称模糊匹配获取模型配置列表""" """按名称模糊匹配获取模型配置列表"""
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit) return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
@staticmethod @staticmethod
async def validate_model_config( async def validate_model_config(
db: Session, db: Session,
*, *,
model_name: str, model_name: str,
provider: str, provider: str,
api_key: str, api_key: str,
api_base: Optional[str] = None, api_base: Optional[str] = None,
model_type: str = "llm", model_type: str = "llm",
test_message: str = "Hello", test_message: str = "Hello",
is_omni: bool = False is_omni: bool = False
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""验证模型配置是否有效 """验证模型配置是否有效
@@ -158,13 +160,13 @@ class ModelConfigService:
# 统一使用 RedBearEmbeddings自动支持火山引擎多模态 # 统一使用 RedBearEmbeddings自动支持火山引擎多模态
embedding = RedBearEmbeddings(model_config) embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"] test_texts = [test_message, "测试文本"]
# 火山引擎使用 embed_batch其他使用 embed_documents # 火山引擎使用 embed_batch其他使用 embed_documents
if provider.lower() == "volcano": if provider.lower() == "volcano":
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts) vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
else: else:
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
return { return {
@@ -200,11 +202,11 @@ class ModelConfigService:
}, },
"error": None "error": None
} }
elif model_type_lower == "image": elif model_type_lower == "image":
# 图片生成模型验证 # 图片生成模型验证
from app.core.models.generation import RedBearImageGenerator from app.core.models.generation import RedBearImageGenerator
generator = RedBearImageGenerator(model_config) generator = RedBearImageGenerator(model_config)
result = await generator.agenerate( result = await generator.agenerate(
prompt="a cute panda", prompt="a cute panda",
@@ -212,7 +214,7 @@ class ModelConfigService:
) )
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"成功生成图片,结果: {result}") logger.info(f"成功生成图片,结果: {result}")
return { return {
"valid": True, "valid": True,
"message": "图片生成模型配置验证成功", "message": "图片生成模型配置验证成功",
@@ -224,21 +226,21 @@ class ModelConfigService:
}, },
"error": None "error": None
} }
elif model_type_lower == "video": elif model_type_lower == "video":
# 视频生成模型验证 # 视频生成模型验证
from app.core.models.generation import RedBearVideoGenerator from app.core.models.generation import RedBearVideoGenerator
generator = RedBearVideoGenerator(model_config) generator = RedBearVideoGenerator(model_config)
result = await generator.agenerate( result = await generator.agenerate(
prompt="a cute panda playing in bamboo forest", prompt="a cute panda playing in bamboo forest",
duration=5 duration=5
) )
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# 视频生成是异步任务返回任务ID # 视频生成是异步任务返回任务ID
task_id = result.get("task_id") if isinstance(result, dict) else None task_id = result.get("task_id") if isinstance(result, dict) else None
return { return {
"valid": True, "valid": True,
"message": "视频生成模型配置验证成功", "message": "视频生成模型配置验证成功",
@@ -265,7 +267,6 @@ class ModelConfigService:
# 提取详细的错误信息 # 提取详细的错误信息
error_message = str(e) error_message = str(e)
error_type = type(e).__name__ error_type = type(e).__name__
print("=========error_message:",error_message.lower())
# 特殊处理常见的错误类型 # 特殊处理常见的错误类型
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower(): if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
# 区域/国家限制(适用于所有提供商) # 区域/国家限制(适用于所有提供商)
@@ -354,14 +355,16 @@ class ModelConfigService:
return model return model
@staticmethod @staticmethod
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig: def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""更新模型配置""" """更新模型配置"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model: if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name: if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
@@ -370,25 +373,27 @@ class ModelConfigService:
return model return model
@staticmethod @staticmethod
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
tenant_id: uuid.UUID) -> ModelConfig:
"""创建组合模型""" """创建组合模型"""
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id): if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# 验证所有 API Key 存在且类型匹配 # 验证所有 API Key 存在且类型匹配
for api_key_id in model_data.api_key_ids: for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key: if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
# 检查 API Key 关联的模型配置类型 # 检查 API Key 关联的模型配置类型
for model_config in api_key.model_configs: for model_config in api_key.model_configs:
# chat 和 llm 类型可以兼容 # chat 和 llm 类型可以兼容
compatible_types = {ModelType.LLM, ModelType.CHAT} compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type config_type = model_config.type
request_type = model_data.type request_type = model_data.type
if not (config_type == request_type or if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)): (config_type in compatible_types and request_type in compatible_types)):
raise BusinessException( raise BusinessException(
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
@@ -399,7 +404,7 @@ class ModelConfigService:
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型", # f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
# BizCode.INVALID_PARAMETER # BizCode.INVALID_PARAMETER
# ) # )
# 创建组合模型 # 创建组合模型
model_config_data = { model_config_data = {
"tenant_id": tenant_id, "tenant_id": tenant_id,
@@ -418,49 +423,51 @@ class ModelConfigService:
model = ModelConfigRepository.create(db, model_config_data) model = ModelConfigRepository.create(db, model_config_data)
db.flush() db.flush()
# 关联 API Keys # 关联 API Keys
for api_key_id in model_data.api_key_ids: for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if api_key: if api_key:
model.api_keys.append(api_key) model.api_keys.append(api_key)
db.commit() db.commit()
db.refresh(model) db.refresh(model)
return model return model
@staticmethod @staticmethod
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
tenant_id: uuid.UUID) -> ModelConfig:
"""更新组合模型""" """更新组合模型"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model: if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name: if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
if not existing_model.is_composite: if not existing_model.is_composite:
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER) raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
# 验证所有 API Key 存在且类型匹配 # 验证所有 API Key 存在且类型匹配
for api_key_id in model_data.api_key_ids: for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key: if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
for model_config in api_key.model_configs: for model_config in api_key.model_configs:
compatible_types = {ModelType.LLM, ModelType.CHAT} compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type config_type = model_config.type
request_type = existing_model.type request_type = existing_model.type
if not (config_type == request_type or if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)): (config_type in compatible_types and request_type in compatible_types)):
raise BusinessException( raise BusinessException(
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
BizCode.INVALID_PARAMETER BizCode.INVALID_PARAMETER
) )
# 更新基本信息 # 更新基本信息
existing_model.name = model_data.name existing_model.name = model_data.name
# existing_model.type = model_data.type # existing_model.type = model_data.type
@@ -471,14 +478,14 @@ class ModelConfigService:
existing_model.is_public = model_data.is_public existing_model.is_public = model_data.is_public
if "load_balance_strategy" in model_data.model_fields_set: if "load_balance_strategy" in model_data.model_fields_set:
existing_model.load_balance_strategy = model_data.load_balance_strategy existing_model.load_balance_strategy = model_data.load_balance_strategy
# 更新 API Keys 关联 # 更新 API Keys 关联
existing_model.api_keys.clear() existing_model.api_keys.clear()
for api_key_id in model_data.api_key_ids: for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if api_key: if api_key:
existing_model.api_keys.append(api_key) existing_model.api_keys.append(api_key)
db.commit() db.commit()
db.refresh(existing_model) db.refresh(existing_model)
return existing_model return existing_model
@@ -532,7 +539,7 @@ class ModelApiKeyService:
"""根据provider为多个ModelConfig创建API Key""" """根据provider为多个ModelConfig创建API Key"""
created_keys = [] created_keys = []
failed_models = [] # 记录验证失败的模型 failed_models = [] # 记录验证失败的模型
for model_config_id in data.model_config_ids: for model_config_id in data.model_config_ids:
model_config = ModelConfigRepository.get_by_id(db, model_config_id) model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config: if not model_config:
@@ -540,10 +547,10 @@ class ModelApiKeyService:
data.is_omni = model_config.is_omni data.is_omni = model_config.is_omni
data.capability = model_config.capability data.capability = model_config.capability
# 从ModelBase获取model_name # 从ModelBase获取model_name
model_name = model_config.model_base.name if model_config.model_base else model_config.name model_name = model_config.model_base.name if model_config.model_base else model_config.name
# 检查是否存在API Key包括软删除需要考虑tenant_id # 检查是否存在API Key包括软删除需要考虑tenant_id
existing_key = db.query(ModelApiKey).join( existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs ModelApiKey.model_configs
@@ -553,7 +560,7 @@ class ModelApiKeyService:
ModelApiKey.model_name == model_name, ModelApiKey.model_name == model_name,
ModelConfig.tenant_id == model_config.tenant_id ModelConfig.tenant_id == model_config.tenant_id
).first() ).first()
if existing_key: if existing_key:
# 如果已存在,重新激活并更新 # 如果已存在,重新激活并更新
if existing_key.is_active: if existing_key.is_active:
@@ -566,14 +573,14 @@ class ModelApiKeyService:
existing_key.model_name = model_name existing_key.model_name = model_name
existing_key.capability = data.capability existing_key.capability = data.capability
existing_key.is_omni = data.is_omni existing_key.is_omni = data.is_omni
# 检查是否已关联该模型配置 # 检查是否已关联该模型配置
if model_config not in existing_key.model_configs: if model_config not in existing_key.model_configs:
existing_key.model_configs.append(model_config) existing_key.model_configs.append(model_config)
created_keys.append(existing_key) created_keys.append(existing_key)
continue continue
# 验证配置 # 验证配置
validation_result = await ModelConfigService.validate_model_config( validation_result = await ModelConfigService.validate_model_config(
db=db, db=db,
@@ -589,7 +596,7 @@ class ModelApiKeyService:
# 记录验证失败的模型,但不抛出异常 # 记录验证失败的模型,但不抛出异常
failed_models.append(model_name) failed_models.append(model_name)
continue continue
# 创建API Key # 创建API Key
api_key_data = ModelApiKeyCreate( api_key_data = ModelApiKeyCreate(
model_config_ids=[model_config_id], model_config_ids=[model_config_id],
@@ -606,12 +613,12 @@ class ModelApiKeyService:
) )
api_key_obj = ModelApiKeyRepository.create(db, api_key_data) api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
created_keys.append(api_key_obj) created_keys.append(api_key_obj)
if created_keys: if created_keys:
db.commit() db.commit()
for key in created_keys: for key in created_keys:
db.refresh(key) db.refresh(key)
return created_keys, failed_models return created_keys, failed_models
@staticmethod @staticmethod
@@ -626,7 +633,7 @@ class ModelApiKeyService:
api_key_data.is_omni = model_config.is_omni api_key_data.is_omni = model_config.is_omni
if api_key_data.capability is None: if api_key_data.capability is None:
api_key_data.capability = model_config.capability api_key_data.capability = model_config.capability
# 检查API Key是否已存在(包括软删除)需要考虑tenant_id # 检查API Key是否已存在(包括软删除)需要考虑tenant_id
existing_key = db.query(ModelApiKey).join( existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs ModelApiKey.model_configs
@@ -650,15 +657,15 @@ class ModelApiKeyService:
existing_key.model_name = api_key_data.model_name existing_key.model_name = api_key_data.model_name
existing_key.capability = api_key_data.capability existing_key.capability = api_key_data.capability
existing_key.is_omni = api_key_data.is_omni existing_key.is_omni = api_key_data.is_omni
# 检查是否已关联该模型配置 # 检查是否已关联该模型配置
if model_config not in existing_key.model_configs: if model_config not in existing_key.model_configs:
existing_key.model_configs.append(model_config) existing_key.model_configs.append(model_config)
db.commit() db.commit()
db.refresh(existing_key) db.refresh(existing_key)
return existing_key return existing_key
# 验证配置 # 验证配置
validation_result = await ModelConfigService.validate_model_config( validation_result = await ModelConfigService.validate_model_config(
db=db, db=db,
@@ -691,7 +698,7 @@ class ModelApiKeyService:
# 获取关联的模型配置以获取模型类型 # 获取关联的模型配置以获取模型类型
if existing_api_key.model_configs: if existing_api_key.model_configs:
model_config = existing_api_key.model_configs[0] model_config = existing_api_key.model_configs[0]
validation_result = await ModelConfigService.validate_model_config( validation_result = await ModelConfigService.validate_model_config(
db=db, db=db,
model_name=api_key_data.model_name or existing_api_key.model_name, model_name=api_key_data.model_name or existing_api_key.model_name,
@@ -729,15 +736,15 @@ class ModelApiKeyService:
model_config = ModelConfigRepository.get_by_id(db, model_config_id) model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config: if not model_config:
return None return None
api_keys = [key for key in model_config.api_keys if key.is_active] api_keys = [key for key in model_config.api_keys if key.is_active]
if not api_keys: if not api_keys:
return None return None
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的 # 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
# 否则返回第一个 # 否则返回第一个
return api_keys[0] return api_keys[0]
@@ -760,20 +767,19 @@ class ModelApiKeyService:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
class ModelBaseService: class ModelBaseService:
"""基础模型服务""" """基础模型服务"""
@staticmethod @staticmethod
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List: def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
models = ModelBaseRepository.get_list(db, query) models = ModelBaseRepository.get_list(db, query)
provider_groups = {} provider_groups = {}
for m in models: for m in models:
model_dict = model_schema.ModelBase.model_validate(m).model_dump() model_dict = model_schema.ModelBase.model_validate(m).model_dump()
if tenant_id: if tenant_id:
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id) model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
provider = m.provider provider = m.provider
if provider not in provider_groups: if provider not in provider_groups:
provider_groups[provider] = { provider_groups[provider] = {
@@ -781,7 +787,7 @@ class ModelBaseService:
"models": [] "models": []
} }
provider_groups[provider]["models"].append(model_dict) provider_groups[provider]["models"].append(model_dict)
return list(provider_groups.values()) return list(provider_groups.values())
@staticmethod @staticmethod
@@ -823,10 +829,10 @@ class ModelBaseService:
model_base = ModelBaseRepository.get_by_id(db, model_base_id) model_base = ModelBaseRepository.get_by_id(db, model_base_id)
if not model_base: if not model_base:
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id): if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME) raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
model_config_data = { model_config_data = {
"model_id": model_base_id, "model_id": model_base_id,
"tenant_id": tenant_id, "tenant_id": tenant_id,

View File

@@ -1,26 +1,24 @@
"""基于分享链接的聊天服务""" """基于分享链接的聊天服务"""
import uuid
import time
import asyncio import asyncio
import json
import time
import uuid
from typing import Optional, Dict, Any, AsyncGenerator from typing import Optional, Dict, Any, AsyncGenerator
from deprecated import deprecated
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.repositories.model_repository import ModelApiKeyRepository from app.core.error_codes import BizCode
from app.services.memory_konwledges_server import write_rag from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.logging_config import get_business_logger
from app.models import MultiAgentConfig
from app.models import ReleaseShare, AppRelease, Conversation from app.models import ReleaseShare, AppRelease, Conversation
from app.repositories import knowledge_repository
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_web_search_tool from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.release_share_service import ReleaseShareService
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.services.multi_agent_service import MultiAgentService from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig from app.services.release_share_service import ReleaseShareService
from app.repositories import knowledge_repository
import json
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
logger = get_business_logger() logger = get_business_logger()
@@ -118,6 +116,7 @@ class SharedChatService:
return conversation return conversation
@deprecated("Use the chat method under app_chat_service instead.")
async def chat( async def chat(
self, self,
share_token: str, share_token: str,
@@ -136,10 +135,7 @@ class SharedChatService:
config_id = actual_config_id config_id = actual_config_id
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.model_parameter_merger import ModelParameterMerger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
start_time = time.time() start_time = time.time()
actual_config_id = None actual_config_id = None
@@ -273,11 +269,6 @@ class SharedChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
) )
# 保存消息 # 保存消息
@@ -324,6 +315,7 @@ class SharedChatService:
"elapsed_time": elapsed_time "elapsed_time": elapsed_time
} }
@deprecated("Use the chat method under app_chat_service instead.")
async def chat_stream( async def chat_stream(
self, self,
share_token: str, share_token: str,
@@ -341,8 +333,6 @@ class SharedChatService:
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
import json import json
start_time = time.time() start_time = time.time()
@@ -486,11 +476,6 @@ class SharedChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
): ):
if isinstance(chunk, int): if isinstance(chunk, int):
total_tokens = chunk total_tokens = chunk
@@ -585,6 +570,7 @@ class SharedChatService:
return conversations, total return conversations, total
@deprecated("Use the chat method under app_chat_service instead.")
async def multi_agent_chat( async def multi_agent_chat(
self, self,
share_token: str, share_token: str,
@@ -680,6 +666,7 @@ class SharedChatService:
"elapsed_time": elapsed_time "elapsed_time": elapsed_time
} }
@deprecated("Use the chat method under app_chat_service instead.")
async def multi_agent_chat_stream( async def multi_agent_chat_stream(
self, self,
share_token: str, share_token: str,