feat(agent, memory): add agent-perceived memory writing

This commit is contained in:
Eternity
2026-03-30 11:47:58 +08:00
parent a5bce221bd
commit 7acb7045f0
12 changed files with 304 additions and 530 deletions

View File

@@ -410,30 +410,6 @@ async def chat(
agent_config = agent_config_4_app_release(release) agent_config = agent_config_4_app_release(release)
if payload.stream: if payload.stream:
# async def event_generator():
# async for event in service.chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
async def event_generator(): async def event_generator():
async for event in app_chat_service.agnet_chat_stream( async for event in app_chat_service.agnet_chat_stream(
message=payload.message, message=payload.message,
@@ -459,20 +435,6 @@ async def chat(
"X-Accel-Buffering": "no" "X-Accel-Buffering": "no"
} }
) )
# 非流式返回
# result = await service.chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
result = await app_chat_service.agnet_chat( result = await app_chat_service.agnet_chat(
message=payload.message, message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID conversation_id=conversation.id, # 使用已创建的会话 ID
@@ -531,48 +493,6 @@ async def chat(
) )
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
# 多 Agent 流式返回
# if payload.stream:
# async def event_generator():
# async for event in service.multi_agent_chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
# # 多 Agent 非流式返回
# result = await service.multi_agent_chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW: elif app_type == AppType.WORKFLOW:
config = workflow_config_4_app_release(release) config = workflow_config_4_app_release(release)
if not config.id: if not config.id:

View File

@@ -11,18 +11,14 @@ LangChain Agent 封装
import time import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType, ModelProvider
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
logger = get_business_logger() logger = get_business_logger()
@@ -226,10 +222,9 @@ class LangChainAgent:
Returns: Returns:
List[BaseMessage]: 消息列表 List[BaseMessage]: 消息列表
""" """
messages = [] messages:list = [SystemMessage(content=self.system_prompt)]
# 添加系统提示词 # 添加系统提示词
messages.append(SystemMessage(content=self.system_prompt))
# 添加历史消息 # 添加历史消息
if history: if history:
@@ -293,12 +288,7 @@ class LangChainAgent:
message: str, message: str,
history: Optional[List[Dict[str, str]]] = None, history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None, context: Optional[str] = None,
end_user_id: Optional[str] = None, files: Optional[List[Dict[str, Any]]] = None
config_id: Optional[str] = None, # 添加这个参数
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""执行对话 """执行对话
@@ -306,32 +296,12 @@ class LangChainAgent:
message: 用户消息 message: 用户消息
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}] history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
context: 上下文信息(如知识库检索结果) context: 上下文信息(如知识库检索结果)
files: 多模态文件
Returns: Returns:
Dict: 包含 content 和元数据的字典 Dict: 包含 content 和元数据的字典
""" """
message_chat = message
start_time = time.time() start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try: try:
# 准备消息列表(支持多模态) # 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files) messages = self._prepare_messages(message, history, context, files)
@@ -419,9 +389,6 @@ class LangChainAgent:
logger.info(f"最终提取的内容长度: {len(content)}") logger.info(f"最终提取的内容长度: {len(content)}")
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
response = { response = {
"content": content, "content": content,
"model": self.model_name, "model": self.model_name,
@@ -452,12 +419,7 @@ class LangChainAgent:
message: str, message: str,
history: Optional[List[Dict[str, str]]] = None, history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None, context: Optional[str] = None,
end_user_id: Optional[str] = None, files: Optional[List[Dict[str, Any]]] = None
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""执行流式对话 """执行流式对话
@@ -465,6 +427,7 @@ class LangChainAgent:
message: 用户消息 message: 用户消息
history: 历史消息列表 history: 历史消息列表
context: 上下文信息 context: 上下文信息
files: 多模态文件
Yields: Yields:
str: 消息内容块 str: 消息内容块
@@ -475,23 +438,6 @@ class LangChainAgent:
logger.info(f" Has tools: {bool(self.tools)}") logger.info(f" Has tools: {bool(self.tools)}")
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}") logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80) logger.info("=" * 80)
message_chat = message
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try: try:
# 准备消息列表(支持多模态) # 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files) messages = self._prepare_messages(message, history, context, files)
@@ -501,17 +447,18 @@ class LangChainAgent:
) )
chunk_count = 0 chunk_count = 0
yielded_content = False
# 统一使用 agent 的 astream_events 实现流式输出 # 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出") logger.debug("使用 Agent astream_events 实现流式输出")
full_content = '' full_content = ''
try: try:
last_event = {}
async for event in self.agent.astream_events( async for event in self.agent.astream_events(
{"messages": messages}, {"messages": messages},
version="v2", version="v2",
config={"recursion_limit": self.max_iterations} config={"recursion_limit": self.max_iterations}
): ):
last_event = event
chunk_count += 1 chunk_count += 1
kind = event.get("event") kind = event.get("event")
@@ -525,7 +472,6 @@ class LangChainAgent:
if isinstance(chunk_content, str) and chunk_content: if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content full_content += chunk_content
yield chunk_content yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list): elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分 # 多模态响应:提取文本部分
for item in chunk_content: for item in chunk_content:
@@ -536,18 +482,15 @@ class LangChainAgent:
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."} # OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text": elif item.get("type") == "text":
text = item.get("text", "") text = item.get("text", "")
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
elif isinstance(item, str): elif isinstance(item, str):
full_content += item full_content += item
yield item yield item
yielded_content = True
elif kind == "on_llm_stream": elif kind == "on_llm_stream":
# 另一种 LLM 流式事件 # 另一种 LLM 流式事件
@@ -558,7 +501,6 @@ class LangChainAgent:
if isinstance(chunk_content, str) and chunk_content: if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content full_content += chunk_content
yield chunk_content yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list): elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分 # 多模态响应:提取文本部分
for item in chunk_content: for item in chunk_content:
@@ -569,22 +511,18 @@ class LangChainAgent:
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."} # OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text": elif item.get("type") == "text":
text = item.get("text", "") text = item.get("text", "")
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
elif isinstance(item, str): elif isinstance(item, str):
full_content += item full_content += item
yield item yield item
yielded_content = True
elif isinstance(chunk, str): elif isinstance(chunk, str):
full_content += chunk full_content += chunk
yield chunk yield chunk
yielded_content = True
# 记录工具调用(可选) # 记录工具调用(可选)
elif kind == "on_tool_start": elif kind == "on_tool_start":
@@ -594,7 +532,7 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗 # 统计token消耗
output_messages = event.get("data", {}).get("output", {}).get("messages", []) output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages): for msg in reversed(output_messages):
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
@@ -604,9 +542,7 @@ class LangChainAgent:
) if response_meta else 0 ) if response_meta else 0
yield total_tokens yield total_tokens
break break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
except Exception as e: except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise raise

View File

@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
"""
Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
ai_message: AI's response message content
user_rag_memory_id: RAG memory identifier for storage location
"""
# RAG mode: combine messages into string format (maintain original logic)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write( async def write(
storage_type, storage_type,
end_user_id, end_user_id,
@@ -118,7 +98,7 @@ async def write(
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope): async def term_memory_save(end_user_id, strategy_type, scope):
""" """
Save long-term memory data to database Save long-term memory data to database
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
to long-term memory storage. to long-term memory storage.
Args: Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
end_user_id: User identifier for memory association end_user_id: User identifier for memory association
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
scope: Scope/window size for memory processing scope: Scope/window size for memory processing
""" """
with get_db_context() as db_session: with get_db_context() as db_session:
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id) result = write_store.get_session_by_userid(end_user_id)
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: if not result:
logger.warning(f"No write data found for user {end_user_id}")
return
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
data = await format_parsing(result, "dict") data = await format_parsing(result, "dict")
chunk_data = data[:scope] chunk_data = data[:scope]
if len(chunk_data) == scope: if len(chunk_data) == scope:
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
logger.info(f'写入短长期:') logger.info(f'写入短长期:')
"""Window-based dialogue processing"""
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope): async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
""" """
Process dialogue based on window size and write to Neo4j Process dialogue based on window size and write to Neo4j
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
langchain_messages: Original message data list langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage scope: Window size determining when to trigger long-term storage
""" """
scope = scope is_end_user_has_history = count_store.get_sessions_count(end_user_id)
is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_has_history:
if is_end_user_id is not False: end_user_visit_count, redis_messages = is_end_user_has_history
is_end_user_id = count_store.get_sessions_count(end_user_id)[0] else:
redis_messages = count_store.get_sessions_count(end_user_id)[1] count_store.save_sessions_count(end_user_id, 1, langchain_messages)
if is_end_user_id and int(is_end_user_id) != int(scope): return
is_end_user_id += 1 end_user_visit_count += 1
langchain_messages += redis_messages if end_user_visit_count < scope:
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) redis_messages.extend(langchain_messages)
elif int(is_end_user_id) == int(scope): count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
else:
logger.info('写入长期记忆NEO4J') logger.info('写入长期记忆NEO4J')
formatted_messages = redis_messages redis_messages.extend(langchain_messages)
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly) # Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'): if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id config_id = memory_config.config_id
else: else:
config_id = memory_config config_id = memory_config
await write( write_message_task.delay(
AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, # end_user_id: User ID
end_user_id, redis_messages, # message: JSON string format message list
"", config_id, # config_id: Configuration ID string
"", AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
None, "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
end_user_id,
config_id,
formatted_messages
) )
count_store.update_sessions_count(end_user_id, 1, langchain_messages) count_store.update_sessions_count(end_user_id, 0, [])
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""Time-based memory processing"""
async def memory_long_term_storage(end_user_id, memory_config, time): async def memory_long_term_storage(end_user_id, memory_config, time):
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
return result_dict return result_dict
except Exception as e: except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}") logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
import traceback
traceback.print_exc()
return { return {
"is_same_event": False, "is_same_event": False,

View File

@@ -1,49 +1,25 @@
import asyncio
import json
import sys
import warnings import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
from app.db import get_db_context
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import write_rag
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
async def long_term_storage(
@asynccontextmanager long_term_type: str,
async def make_write_graph(): langchain_messages: list,
""" memory_config_id: str,
Create a write graph workflow for memory operations. end_user_id: str,
scope: int = 6
Args: ):
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()
yield graph
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
""" """
Handle long-term memory storage with different strategies Handle long-term memory storage with different strategies
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
Args: Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store langchain_messages: List of messages to store
memory_config: Memory configuration identifier memory_config_id: Memory configuration identifier
end_user_id: User group identifier end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6) scope: Scope parameter for chunk-based storage (default: 6)
""" """
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ if langchain_messages is None:
aggregate_judgment langchain_messages = []
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, langchain_messages) write_store.save_session_write(end_user_id, langchain_messages)
# 获取数据库会话 # 获取数据库会话
with get_db_context() as db_session: with get_db_context() as db_session:
config_service = MemoryConfigService(db_session) config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数 config_id=memory_config_id, # 改为整数
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK: if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''Strategy 1: Dialogue window with 6 rounds of conversation''' # Dialogue window with 6 rounds of conversation
await window_dialogue(end_user_id, langchain_messages, memory_config, scope) await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME: if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""Time-based strategy""" # Time-based strategy
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE) await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE: if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""Strategy 3: Aggregate judgment""" # Aggregate judgment
await aggregate_judgment(end_user_id, langchain_messages, memory_config) await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id): async def write_long_term(
storage_type: str,
end_user_id: str,
messages: list[dict],
user_rag_memory_id: str,
actual_config_id: str
):
""" """
Write long-term memory with different storage types Write long-term memory with different storage types
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
Args: Args:
storage_type: Type of storage (RAG or traditional) storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier end_user_id: User group identifier
message_chat: User message content messages: message list
aimessages: AI response messages
user_rag_memory_id: RAG memory identifier user_rag_memory_id: RAG memory identifier
actual_config_id: Actual configuration ID actual_config_id: Actual configuration ID
""" """
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG: if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) message_content = []
for message in messages:
message_content.append(f'{message.get("role")}:{message.get("content")}')
messages_string = "\n".join(message_content)
await write_rag(end_user_id, messages_string, user_rag_memory_id)
else: else:
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once) # AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages) await long_term_storage(long_term_type=CHUNK,
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, langchain_messages=messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) memory_config_id=actual_config_id,
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) end_user_id=end_user_id,
scope=SCOPE)
# async def main(): await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "好耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -3,6 +3,7 @@ import uuid
from app.core.config import settings from app.core.config import settings
from typing import List, Dict, Any, Optional, Union from typing import List, Dict, Any, Optional, Union
from app.core.logging_config import get_logger
from app.core.memory.agent.utils.redis_base import ( from app.core.memory.agent.utils.redis_base import (
serialize_messages, serialize_messages,
deserialize_messages, deserialize_messages,
@@ -14,7 +15,7 @@ from app.core.memory.agent.utils.redis_base import (
get_current_timestamp get_current_timestamp
) )
logger = get_logger(__name__)
class RedisWriteStore: class RedisWriteStore:
@@ -66,10 +67,10 @@ class RedisWriteStore:
}) })
result = pipe.execute() result = pipe.execute()
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
return session_id return session_id
except Exception as e: except Exception as e:
print(f"[save_session_write] 保存会话失败: {e}") logger.error(f"[save_session_write] 保存会话失败: {e}")
raise e raise e
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]: def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
@@ -112,10 +113,10 @@ class RedisWriteStore:
if not results: if not results:
return False return False
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
return results return results
except Exception as e: except Exception as e:
print(f"[get_session_by_userid] 查询失败: {e}") logger.error(f"[get_session_by_userid] 查询失败: {e}")
return False return False
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]: def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
@@ -144,7 +145,7 @@ class RedisWriteStore:
# 只查询 write 类型的 key # 只查询 write 类型的 key
keys = self.r.keys('session:write:*') keys = self.r.keys('session:write:*')
if not keys: if not keys:
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
return False return False
# 批量获取数据 # 批量获取数据
@@ -175,18 +176,16 @@ class RedisWriteStore:
results.append(session_info) results.append(session_info)
if not results: if not results:
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
return False return False
# 按时间排序(最新的在前) # 按时间排序(最新的在前)
results.sort(key=lambda x: x.get('starttime', ''), reverse=True) results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
return results return results
except Exception as e: except Exception as e:
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}") logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
import traceback
traceback.print_exc()
return False return False
def find_user_recent_sessions(self, userid: str, def find_user_recent_sessions(self, userid: str,
@@ -207,7 +206,7 @@ class RedisWriteStore:
# 只查询 write 类型的 key # 只查询 write 类型的 key
keys = self.r.keys('session:write:*') keys = self.r.keys('session:write:*')
if not keys: if not keys:
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return [] return []
# 批量获取数据 # 批量获取数据
@@ -234,11 +233,10 @@ class RedisWriteStore:
# 根据时间范围过滤 # 根据时间范围过滤
filtered_items = filter_by_time_range(matched_items, minutes) filtered_items = filter_by_time_range(matched_items, minutes)
# 排序并移除时间字段 # 排序并移除时间字段
result_items = sort_and_limit_results(filtered_items, limit=None) result_items = sort_and_limit_results(filtered_items)
print(result_items)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items return result_items
@@ -278,7 +276,7 @@ class RedisCountStore:
decode_responses=True, decode_responses=True,
encoding='utf-8' encoding='utf-8'
) )
self.uudi = session_id self.uuid = session_id
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str: def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
""" """
@@ -298,7 +296,7 @@ class RedisCountStore:
pipe = self.r.pipeline() pipe = self.r.pipeline()
pipe.hset(key, mapping={ pipe.hset(key, mapping={
"id": self.uudi, "id": self.uuid,
"end_user_id": end_user_id, "end_user_id": end_user_id,
"count": int(count), "count": int(count),
"messages": serialize_messages(messages), "messages": serialize_messages(messages),
@@ -311,10 +309,10 @@ class RedisCountStore:
result = pipe.execute() result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
return session_id return session_id
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]: def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
""" """
通过 end_user_id 查询访问次数统计 通过 end_user_id 查询访问次数统计
@@ -335,7 +333,7 @@ class RedisCountStore:
self.r.delete(index_key) self.r.delete(index_key)
return False return False
except Exception as type_error: except Exception as type_error:
print(f"[get_sessions_count] 检查键类型失败: {type_error}") logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key) session_id = self.r.get(index_key)
@@ -355,15 +353,20 @@ class RedisCountStore:
messages_str = data.get('messages') messages_str = data.get('messages')
if count is not None: if count is not None:
messages = deserialize_messages(messages_str) messages: list[dict] = deserialize_messages(messages_str)
return [int(count), messages] return int(count), messages
return False return False
except Exception as e: except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}") logger.error(f"[get_sessions_count] 查询失败: {e}")
return False return False
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool: def update_sessions_count(
self,
end_user_id: str,
new_count: int,
messages: Any
) -> bool:
""" """
通过 end_user_id 修改访问次数统计(优化版:使用索引) 通过 end_user_id 修改访问次数统计(优化版:使用索引)
@@ -384,17 +387,17 @@ class RedisCountStore:
key_type = self.r.type(index_key) key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none': if key_type != 'string' and key_type != 'none':
# 索引键类型错误,删除并返回 False # 索引键类型错误,删除并返回 False
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key) self.r.delete(index_key)
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False return False
except Exception as type_error: except Exception as type_error:
print(f"[update_sessions_count] 检查键类型失败: {type_error}") logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key) session_id = self.r.get(index_key)
if not session_id: if not session_id:
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False return False
# 直接更新数据 # 直接更新数据
@@ -402,15 +405,15 @@ class RedisCountStore:
messages_str = serialize_messages(messages) messages_str = serialize_messages(messages)
pipe = self.r.pipeline() pipe = self.r.pipeline()
pipe.hset(key, 'count', int(new_count)) pipe.hset(key, 'count', str(new_count))
pipe.hset(key, 'messages', messages_str) pipe.hset(key, 'messages', messages_str)
result = pipe.execute() result = pipe.execute()
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True return True
except Exception as e: except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}") logger.debug(f"[update_sessions_count] 更新失败: {e}")
return False return False
def delete_all_count_sessions(self) -> int: def delete_all_count_sessions(self) -> int:
@@ -453,7 +456,7 @@ class RedisSessionStore:
# ==================== 写入操作 ==================== # ==================== 写入操作 ====================
def save_session(self, userid: str, messages: str, aimessages: str, def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str: apply_id: str, end_user_id: str) -> str:
""" """
写入一条会话数据,返回 session_id 写入一条会话数据,返回 session_id
@@ -483,10 +486,10 @@ class RedisSessionStore:
}) })
result = pipe.execute() result = pipe.execute()
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
return session_id return session_id
except Exception as e: except Exception as e:
print(f"[save_session] 保存会话失败: {e}") logger.error(f"[save_session] 保存会话失败: {e}")
raise e raise e
# ==================== 读取操作 ==================== # ==================== 读取操作 ====================
@@ -521,7 +524,7 @@ class RedisSessionStore:
return sessions return sessions
def find_user_apply_group(self, sessionid: str, apply_id: str, def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]: end_user_id: str) -> List[Dict[str, str]]:
""" """
根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条 根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条
@@ -538,7 +541,7 @@ class RedisSessionStore:
keys = self.r.keys('session:*') keys = self.r.keys('session:*')
if not keys: if not keys:
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return [] return []
# 批量获取数据 # 批量获取数据
@@ -556,7 +559,7 @@ class RedisSessionStore:
continue continue
if (data.get('apply_id') == apply_id and if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id): data.get('end_user_id') == end_user_id):
# 支持模糊匹配或完全匹配 sessionid # 支持模糊匹配或完全匹配 sessionid
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append(format_session_data(data, include_time=True)) matched_items.append(format_session_data(data, include_time=True))
@@ -565,7 +568,7 @@ class RedisSessionStore:
result_items = sort_and_limit_results(matched_items, limit=6) result_items = sort_and_limit_results(matched_items, limit=6)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items return result_items
@@ -632,7 +635,7 @@ class RedisSessionStore:
keys = self.r.keys('session:*') keys = self.r.keys('session:*')
if not keys: if not keys:
print("[delete_duplicate_sessions] 没有会话数据") logger.debug("[delete_duplicate_sessions] 没有会话数据")
return 0 return 0
# 批量获取所有数据 # 批量获取所有数据
@@ -678,7 +681,7 @@ class RedisSessionStore:
deleted_count += len(batch) deleted_count += len(batch)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}") logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
return deleted_count return deleted_count

View File

@@ -56,7 +56,7 @@ class LLMClient(ABC):
self.max_retries = self.config.max_retries self.max_retries = self.config.max_retries
self.timeout = self.config.timeout self.timeout = self.config.timeout
logger.info( logger.debug(
f"初始化 LLM 客户端: provider={self.provider}, " f"初始化 LLM 客户端: provider={self.provider}, "
f"model={self.model_name}, max_retries={self.max_retries}" f"model={self.model_name}, max_retries={self.max_retries}"
) )

View File

@@ -17,6 +17,7 @@ class Write_UserInput(BaseModel):
end_user_id: str end_user_id: str
config_id: Optional[str] = None config_id: Optional[str] = None
class AgentMemory_Long_Term(ABC): class AgentMemory_Long_Term(ABC):
"""长期记忆配置常量""" """长期记忆配置常量"""
STORAGE_NEO4J = "neo4j" STORAGE_NEO4J = "neo4j"
@@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC):
STRATEGY_CHUNK = "chunk" STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time" STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6 DEFAULT_SCOPE = 6
TIME_SCOPE=5 TIME_SCOPE = 5
class AgentMemoryDataset(ABC):
PRONOUN=['','本人','在下','自己','','鄙人','','']
NAME='用户'
class AgentMemoryDataset(ABC):
PRONOUN = ['', '本人', '在下', '自己', '', '鄙人', '', '']
NAME = '用户'

View File

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig, ModelType from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig from app.models import WorkflowConfig
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.draft_run_service import AgentRunService from app.services.draft_run_service import AgentRunService
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.services.workflow_service import WorkflowService from app.services.workflow_service import WorkflowService
from app.schemas import FileType
logger = get_business_logger() logger = get_business_logger()
@@ -43,18 +44,17 @@ class AppChatService:
message: str, message: str,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
config: AgentConfig, config: AgentConfig,
user_id: Optional[str] = None, files: list[FileInput],
user_id: str,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
web_search: bool = False, web_search: bool = False,
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None
files: Optional[List[FileInput]] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""聊天(非流式)""" """聊天(非流式)"""
start_time = time.time() start_time = time.time()
config_id = None
# 应用 features 配置 # 应用 features 配置
features_config: dict = config.features or {} features_config: dict = config.features or {}
@@ -93,7 +93,8 @@ class AppChatService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
user_id)
tools.extend(kb_tools) tools.extend(kb_tools)
memory_flag = False memory_flag = False
if memory: if memory:
@@ -168,11 +169,6 @@ class AppChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag,
files=processed_files # 传递处理后的文件 files=processed_files # 传递处理后的文件
) )
@@ -229,6 +225,21 @@ class AppChatService:
# 保存消息 # 保存消息
if audio_url: if audio_url:
assistant_meta["audio_url"] = audio_url assistant_meta["audio_url"] = audio_url
if memory_flag:
connected_config = get_end_user_connected_config(user_id, self.db)
memory_config_id: str = connected_config.get("memory_config_id")
messages = [
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
{"role": "assistant", "content": result["content"]}
]
if memory_config_id:
await write_long_term(
storage_type,
user_id,
messages,
user_rag_memory_id,
memory_config_id
)
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation_id, conversation_id=conversation_id,
role="user", role="user",
@@ -264,20 +275,19 @@ class AppChatService:
message: str, message: str,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
config: AgentConfig, config: AgentConfig,
files: list[FileInput],
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
web_search: bool = False, web_search: bool = False,
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None
files: Optional[List[FileInput]] = None
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""聊天(流式)""" """聊天(流式)"""
try: try:
start_time = time.time() start_time = time.time()
config_id = None
message_id = uuid.uuid4() message_id = uuid.uuid4()
# 应用 features 配置 # 应用 features 配置
@@ -319,7 +329,8 @@ class AppChatService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
config.knowledge_retrieval, user_id)
tools.extend(kb_tools) tools.extend(kb_tools)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
@@ -411,11 +422,6 @@ class AppChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag,
files=processed_files files=processed_files
): ):
if isinstance(chunk, int): if isinstance(chunk, int):
@@ -459,7 +465,7 @@ class AppChatService:
# 保存消息 # 保存消息
human_meta = { human_meta = {
"files":[], "files": [],
"history_files": {} "history_files": {}
} }
assistant_meta = { assistant_meta = {
@@ -484,6 +490,22 @@ class AppChatService:
if stream_audio_url: if stream_audio_url:
assistant_meta["audio_url"] = stream_audio_url assistant_meta["audio_url"] = stream_audio_url
if memory_flag:
connected_config = get_end_user_connected_config(user_id, self.db)
memory_config_id: str = connected_config.get("memory_config_id")
messages = [
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
{"role": "assistant", "content": full_content}
]
if memory_config_id:
await write_long_term(
storage_type,
user_id,
messages,
user_rag_memory_id,
memory_config_id
)
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation_id, conversation_id=conversation_id,
role="user", role="user",
@@ -618,7 +640,6 @@ class AppChatService:
# 2. 创建编排器 # 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config) orchestrator = MultiAgentOrchestrator(self.db, config)
# 3. 流式执行任务 # 3. 流式执行任务
async for event in orchestrator.execute_stream( async for event in orchestrator.execute_stream(
message=message, message=message,

View File

@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context from app.db import get_db_context
from app.models import AgentConfig, ModelConfig, ModelType from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput, Citation from app.schemas.app_schema import FileInput, Citation
from app.schemas.model_schema import ModelInfo from app.schemas.model_schema import ModelInfo
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService from app.services.tool_service import ToolService
from app.schemas import FileType
logger = get_business_logger() logger = get_business_logger()
@@ -657,11 +656,6 @@ class AgentRunService:
message=message, message=message,
history=history, history=history,
context=context, context=context,
end_user_id=user_id,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag,
files=processed_files # 传递处理后的文件 files=processed_files # 传递处理后的文件
) )
@@ -911,11 +905,6 @@ class AgentRunService:
message=message, message=message,
history=history, history=history,
context=context, context=context,
end_user_id=user_id,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag,
files=processed_files files=processed_files
): ):
if isinstance(chunk, int): if isinstance(chunk, int):

View File

@@ -243,27 +243,6 @@ class MemoryPerceptualService:
memory_config: MemoryConfig, memory_config: MemoryConfig,
file: FileInput file: FileInput
): ):
memories = self.repository.get_by_url(file.url)
if memories:
business_logger.info(f"Perceptual memory already exists: {file.url}")
if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
file_name=memory_cache.file_name,
file_ext=memory_cache.file_ext,
summary=memory_cache.summary,
meta_data=memory_cache.meta_data
)
self.db.commit()
return memory
else:
for memory in memories:
if memory.end_user_id == uuid.UUID(end_user_id):
return memory
llm, model_config = self._get_mutlimodal_client(file.type, memory_config) llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
multimodel_service = MultimodalService(self.db, ModelInfo( multimodel_service = MultimodalService(self.db, ModelInfo(
model_name=model_config.model_name, model_name=model_config.model_name,

View File

@@ -69,7 +69,8 @@ class ModelConfigService:
return items return items
@staticmethod @staticmethod
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig: def get_model_by_name(db: Session, name: str, provider: str | None = None,
tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""根据名称获取模型配置""" """根据名称获取模型配置"""
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id) model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
if not model: if not model:
@@ -77,21 +78,22 @@ class ModelConfigService:
return model return model
@staticmethod @staticmethod
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]: def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
ModelConfig]:
"""按名称模糊匹配获取模型配置列表""" """按名称模糊匹配获取模型配置列表"""
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit) return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
@staticmethod @staticmethod
async def validate_model_config( async def validate_model_config(
db: Session, db: Session,
*, *,
model_name: str, model_name: str,
provider: str, provider: str,
api_key: str, api_key: str,
api_base: Optional[str] = None, api_base: Optional[str] = None,
model_type: str = "llm", model_type: str = "llm",
test_message: str = "Hello", test_message: str = "Hello",
is_omni: bool = False is_omni: bool = False
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""验证模型配置是否有效 """验证模型配置是否有效
@@ -265,7 +267,6 @@ class ModelConfigService:
# 提取详细的错误信息 # 提取详细的错误信息
error_message = str(e) error_message = str(e)
error_type = type(e).__name__ error_type = type(e).__name__
print("=========error_message:",error_message.lower())
# 特殊处理常见的错误类型 # 特殊处理常见的错误类型
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower(): if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
# 区域/国家限制(适用于所有提供商) # 区域/国家限制(适用于所有提供商)
@@ -354,14 +355,16 @@ class ModelConfigService:
return model return model
@staticmethod @staticmethod
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig: def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""更新模型配置""" """更新模型配置"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model: if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name: if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
@@ -370,9 +373,11 @@ class ModelConfigService:
return model return model
@staticmethod @staticmethod
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
tenant_id: uuid.UUID) -> ModelConfig:
"""创建组合模型""" """创建组合模型"""
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id): if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# 验证所有 API Key 存在且类型匹配 # 验证所有 API Key 存在且类型匹配
@@ -430,14 +435,16 @@ class ModelConfigService:
return model return model
@staticmethod @staticmethod
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
tenant_id: uuid.UUID) -> ModelConfig:
"""更新组合模型""" """更新组合模型"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model: if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name: if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
if not existing_model.is_composite: if not existing_model.is_composite:
@@ -760,7 +767,6 @@ class ModelApiKeyService:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
class ModelBaseService: class ModelBaseService:
"""基础模型服务""" """基础模型服务"""

View File

@@ -1,26 +1,24 @@
"""基于分享链接的聊天服务""" """基于分享链接的聊天服务"""
import uuid
import time
import asyncio import asyncio
import json
import time
import uuid
from typing import Optional, Dict, Any, AsyncGenerator from typing import Optional, Dict, Any, AsyncGenerator
from deprecated import deprecated
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.repositories.model_repository import ModelApiKeyRepository from app.core.error_codes import BizCode
from app.services.memory_konwledges_server import write_rag from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.logging_config import get_business_logger
from app.models import MultiAgentConfig
from app.models import ReleaseShare, AppRelease, Conversation from app.models import ReleaseShare, AppRelease, Conversation
from app.repositories import knowledge_repository
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_web_search_tool from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.release_share_service import ReleaseShareService
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.services.multi_agent_service import MultiAgentService from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig from app.services.release_share_service import ReleaseShareService
from app.repositories import knowledge_repository
import json
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
logger = get_business_logger() logger = get_business_logger()
@@ -118,6 +116,7 @@ class SharedChatService:
return conversation return conversation
@deprecated("Use the chat method under app_chat_service instead.")
async def chat( async def chat(
self, self,
share_token: str, share_token: str,
@@ -136,10 +135,7 @@ class SharedChatService:
config_id = actual_config_id config_id = actual_config_id
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.model_parameter_merger import ModelParameterMerger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
start_time = time.time() start_time = time.time()
actual_config_id = None actual_config_id = None
@@ -273,11 +269,6 @@ class SharedChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
) )
# 保存消息 # 保存消息
@@ -324,6 +315,7 @@ class SharedChatService:
"elapsed_time": elapsed_time "elapsed_time": elapsed_time
} }
@deprecated("Use the chat method under app_chat_service instead.")
async def chat_stream( async def chat_stream(
self, self,
share_token: str, share_token: str,
@@ -341,8 +333,6 @@ class SharedChatService:
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
import json import json
start_time = time.time() start_time = time.time()
@@ -486,11 +476,6 @@ class SharedChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
): ):
if isinstance(chunk, int): if isinstance(chunk, int):
total_tokens = chunk total_tokens = chunk
@@ -585,6 +570,7 @@ class SharedChatService:
return conversations, total return conversations, total
@deprecated("Use the chat method under app_chat_service instead.")
async def multi_agent_chat( async def multi_agent_chat(
self, self,
share_token: str, share_token: str,
@@ -680,6 +666,7 @@ class SharedChatService:
"elapsed_time": elapsed_time "elapsed_time": elapsed_time
} }
@deprecated("Use the chat method under app_chat_service instead.")
async def multi_agent_chat_stream( async def multi_agent_chat_stream(
self, self,
share_token: str, share_token: str,