Merge pull request #724 from SuanmoSuanyangTechnology/feature/memory-agent-perceptual
feat(agent, memory): add agent-perceived memory writing
This commit is contained in:
@@ -410,30 +410,6 @@ async def chat(
|
|||||||
agent_config = agent_config_4_app_release(release)
|
agent_config = agent_config_4_app_release(release)
|
||||||
|
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
# async def event_generator():
|
|
||||||
# async for event in service.chat_stream(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# ):
|
|
||||||
# yield event
|
|
||||||
|
|
||||||
# return StreamingResponse(
|
|
||||||
# event_generator(),
|
|
||||||
# media_type="text/event-stream",
|
|
||||||
# headers={
|
|
||||||
# "Cache-Control": "no-cache",
|
|
||||||
# "Connection": "keep-alive",
|
|
||||||
# "X-Accel-Buffering": "no"
|
|
||||||
# }
|
|
||||||
# )
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in app_chat_service.agnet_chat_stream(
|
async for event in app_chat_service.agnet_chat_stream(
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -459,20 +435,6 @@ async def chat(
|
|||||||
"X-Accel-Buffering": "no"
|
"X-Accel-Buffering": "no"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# 非流式返回
|
|
||||||
# result = await service.chat(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# )
|
|
||||||
# return success(data=conversation_schema.ChatResponse(**result))
|
|
||||||
result = await app_chat_service.agnet_chat(
|
result = await app_chat_service.agnet_chat(
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
@@ -531,48 +493,6 @@ async def chat(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
# 多 Agent 流式返回
|
|
||||||
# if payload.stream:
|
|
||||||
# async def event_generator():
|
|
||||||
# async for event in service.multi_agent_chat_stream(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# ):
|
|
||||||
# yield event
|
|
||||||
|
|
||||||
# return StreamingResponse(
|
|
||||||
# event_generator(),
|
|
||||||
# media_type="text/event-stream",
|
|
||||||
# headers={
|
|
||||||
# "Cache-Control": "no-cache",
|
|
||||||
# "Connection": "keep-alive",
|
|
||||||
# "X-Accel-Buffering": "no"
|
|
||||||
# }
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # 多 Agent 非流式返回
|
|
||||||
# result = await service.multi_agent_chat(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# )
|
|
||||||
|
|
||||||
# return success(data=conversation_schema.ChatResponse(**result))
|
|
||||||
elif app_type == AppType.WORKFLOW:
|
elif app_type == AppType.WORKFLOW:
|
||||||
config = workflow_config_4_app_release(release)
|
config = workflow_config_4_app_release(release)
|
||||||
if not config.id:
|
if not config.id:
|
||||||
|
|||||||
@@ -11,18 +11,14 @@ LangChain Agent 封装
|
|||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
|
||||||
from app.db import get_db
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
|
||||||
from app.models.models_model import ModelType, ModelProvider
|
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config,
|
|
||||||
)
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.models.models_model import ModelType
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -226,10 +222,9 @@ class LangChainAgent:
|
|||||||
Returns:
|
Returns:
|
||||||
List[BaseMessage]: 消息列表
|
List[BaseMessage]: 消息列表
|
||||||
"""
|
"""
|
||||||
messages = []
|
messages:list = [SystemMessage(content=self.system_prompt)]
|
||||||
|
|
||||||
# 添加系统提示词
|
# 添加系统提示词
|
||||||
messages.append(SystemMessage(content=self.system_prompt))
|
|
||||||
|
|
||||||
# 添加历史消息
|
# 添加历史消息
|
||||||
if history:
|
if history:
|
||||||
@@ -293,12 +288,7 @@ class LangChainAgent:
|
|||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
end_user_id: Optional[str] = None,
|
files: Optional[List[Dict[str, Any]]] = None
|
||||||
config_id: Optional[str] = None, # 添加这个参数
|
|
||||||
storage_type: Optional[str] = None,
|
|
||||||
user_rag_memory_id: Optional[str] = None,
|
|
||||||
memory_flag: Optional[bool] = True,
|
|
||||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""执行对话
|
"""执行对话
|
||||||
|
|
||||||
@@ -306,32 +296,12 @@ class LangChainAgent:
|
|||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||||
context: 上下文信息(如知识库检索结果)
|
context: 上下文信息(如知识库检索结果)
|
||||||
|
files: 多模态文件
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含 content 和元数据的字典
|
Dict: 包含 content 和元数据的字典
|
||||||
"""
|
"""
|
||||||
message_chat = message
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
actual_config_id = config_id
|
|
||||||
# If config_id is None, try to get from end_user's connected config
|
|
||||||
if actual_config_id is None and end_user_id:
|
|
||||||
try:
|
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config,
|
|
||||||
)
|
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
|
||||||
actual_config_id = connected_config.get("memory_config_id")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
|
||||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
|
||||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
|
||||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
|
||||||
try:
|
try:
|
||||||
# 准备消息列表(支持多模态)
|
# 准备消息列表(支持多模态)
|
||||||
messages = self._prepare_messages(message, history, context, files)
|
messages = self._prepare_messages(message, history, context, files)
|
||||||
@@ -419,9 +389,6 @@ class LangChainAgent:
|
|||||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
|
||||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
|
||||||
actual_config_id)
|
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -452,12 +419,7 @@ class LangChainAgent:
|
|||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
end_user_id: Optional[str] = None,
|
files: Optional[List[Dict[str, Any]]] = None
|
||||||
config_id: Optional[str] = None,
|
|
||||||
storage_type: Optional[str] = None,
|
|
||||||
user_rag_memory_id: Optional[str] = None,
|
|
||||||
memory_flag: Optional[bool] = True,
|
|
||||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""执行流式对话
|
"""执行流式对话
|
||||||
|
|
||||||
@@ -465,6 +427,7 @@ class LangChainAgent:
|
|||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表
|
history: 历史消息列表
|
||||||
context: 上下文信息
|
context: 上下文信息
|
||||||
|
files: 多模态文件
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
str: 消息内容块
|
str: 消息内容块
|
||||||
@@ -475,23 +438,6 @@ class LangChainAgent:
|
|||||||
logger.info(f" Has tools: {bool(self.tools)}")
|
logger.info(f" Has tools: {bool(self.tools)}")
|
||||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
message_chat = message
|
|
||||||
actual_config_id = config_id
|
|
||||||
# If config_id is None, try to get from end_user's connected config
|
|
||||||
if actual_config_id is None and end_user_id:
|
|
||||||
try:
|
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
|
||||||
actual_config_id = connected_config.get("memory_config_id")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
|
||||||
|
|
||||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
|
||||||
try:
|
try:
|
||||||
# 准备消息列表(支持多模态)
|
# 准备消息列表(支持多模态)
|
||||||
messages = self._prepare_messages(message, history, context, files)
|
messages = self._prepare_messages(message, history, context, files)
|
||||||
@@ -501,17 +447,18 @@ class LangChainAgent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
yielded_content = False
|
|
||||||
|
|
||||||
# 统一使用 agent 的 astream_events 实现流式输出
|
# 统一使用 agent 的 astream_events 实现流式输出
|
||||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||||
full_content = ''
|
full_content = ''
|
||||||
try:
|
try:
|
||||||
|
last_event = {}
|
||||||
async for event in self.agent.astream_events(
|
async for event in self.agent.astream_events(
|
||||||
{"messages": messages},
|
{"messages": messages},
|
||||||
version="v2",
|
version="v2",
|
||||||
config={"recursion_limit": self.max_iterations}
|
config={"recursion_limit": self.max_iterations}
|
||||||
):
|
):
|
||||||
|
last_event = event
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
kind = event.get("event")
|
kind = event.get("event")
|
||||||
|
|
||||||
@@ -525,7 +472,6 @@ class LangChainAgent:
|
|||||||
if isinstance(chunk_content, str) and chunk_content:
|
if isinstance(chunk_content, str) and chunk_content:
|
||||||
full_content += chunk_content
|
full_content += chunk_content
|
||||||
yield chunk_content
|
yield chunk_content
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(chunk_content, list):
|
elif isinstance(chunk_content, list):
|
||||||
# 多模态响应:提取文本部分
|
# 多模态响应:提取文本部分
|
||||||
for item in chunk_content:
|
for item in chunk_content:
|
||||||
@@ -536,18 +482,15 @@ class LangChainAgent:
|
|||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
elif item.get("type") == "text":
|
elif item.get("type") == "text":
|
||||||
text = item.get("text", "")
|
text = item.get("text", "")
|
||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
full_content += item
|
full_content += item
|
||||||
yield item
|
yield item
|
||||||
yielded_content = True
|
|
||||||
|
|
||||||
elif kind == "on_llm_stream":
|
elif kind == "on_llm_stream":
|
||||||
# 另一种 LLM 流式事件
|
# 另一种 LLM 流式事件
|
||||||
@@ -558,7 +501,6 @@ class LangChainAgent:
|
|||||||
if isinstance(chunk_content, str) and chunk_content:
|
if isinstance(chunk_content, str) and chunk_content:
|
||||||
full_content += chunk_content
|
full_content += chunk_content
|
||||||
yield chunk_content
|
yield chunk_content
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(chunk_content, list):
|
elif isinstance(chunk_content, list):
|
||||||
# 多模态响应:提取文本部分
|
# 多模态响应:提取文本部分
|
||||||
for item in chunk_content:
|
for item in chunk_content:
|
||||||
@@ -569,22 +511,18 @@ class LangChainAgent:
|
|||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
elif item.get("type") == "text":
|
elif item.get("type") == "text":
|
||||||
text = item.get("text", "")
|
text = item.get("text", "")
|
||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
full_content += item
|
full_content += item
|
||||||
yield item
|
yield item
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(chunk, str):
|
elif isinstance(chunk, str):
|
||||||
full_content += chunk
|
full_content += chunk
|
||||||
yield chunk
|
yield chunk
|
||||||
yielded_content = True
|
|
||||||
|
|
||||||
# 记录工具调用(可选)
|
# 记录工具调用(可选)
|
||||||
elif kind == "on_tool_start":
|
elif kind == "on_tool_start":
|
||||||
@@ -594,7 +532,7 @@ class LangChainAgent:
|
|||||||
|
|
||||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||||
# 统计token消耗
|
# 统计token消耗
|
||||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||||
@@ -604,9 +542,7 @@ class LangChainAgent:
|
|||||||
) if response_meta else 0
|
) if response_meta else 0
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
break
|
break
|
||||||
if memory_flag:
|
|
||||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
|
||||||
actual_config_id)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_konwledges_server import write_rag
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
from app.services.task_service import get_task_memory_write_result
|
||||||
from app.tasks import write_message_task
|
from app.tasks import write_message_task
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
|
|||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
|
||||||
|
|
||||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
|
||||||
"""
|
|
||||||
Write messages to RAG storage system
|
|
||||||
|
|
||||||
Combines user and AI messages into a single string format and stores them
|
|
||||||
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: User identifier for the conversation
|
|
||||||
user_message: User's input message content
|
|
||||||
ai_message: AI's response message content
|
|
||||||
user_rag_memory_id: RAG memory identifier for storage location
|
|
||||||
"""
|
|
||||||
# RAG mode: combine messages into string format (maintain original logic)
|
|
||||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
|
||||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
|
||||||
|
|
||||||
|
|
||||||
async def write(
|
async def write(
|
||||||
storage_type,
|
storage_type,
|
||||||
end_user_id,
|
end_user_id,
|
||||||
@@ -118,7 +98,7 @@ async def write(
|
|||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
|
|
||||||
|
|
||||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||||
"""
|
"""
|
||||||
Save long-term memory data to database
|
Save long-term memory data to database
|
||||||
|
|
||||||
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
|||||||
to long-term memory storage.
|
to long-term memory storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
long_term_messages: Long-term message data to be saved
|
|
||||||
actual_config_id: Configuration identifier for memory settings
|
|
||||||
end_user_id: User identifier for memory association
|
end_user_id: User identifier for memory association
|
||||||
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||||
scope: Scope/window size for memory processing
|
scope: Scope/window size for memory processing
|
||||||
"""
|
"""
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
|||||||
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
result = write_store.get_session_by_userid(end_user_id)
|
result = write_store.get_session_by_userid(end_user_id)
|
||||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
if not result:
|
||||||
|
logger.warning(f"No write data found for user {end_user_id}")
|
||||||
|
return
|
||||||
|
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
|
||||||
data = await format_parsing(result, "dict")
|
data = await format_parsing(result, "dict")
|
||||||
chunk_data = data[:scope]
|
chunk_data = data[:scope]
|
||||||
if len(chunk_data) == scope:
|
if len(chunk_data) == scope:
|
||||||
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
|||||||
logger.info(f'写入短长期:')
|
logger.info(f'写入短长期:')
|
||||||
|
|
||||||
|
|
||||||
"""Window-based dialogue processing"""
|
|
||||||
|
|
||||||
|
|
||||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||||
"""
|
"""
|
||||||
Process dialogue based on window size and write to Neo4j
|
Process dialogue based on window size and write to Neo4j
|
||||||
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
|||||||
langchain_messages: Original message data list
|
langchain_messages: Original message data list
|
||||||
scope: Window size determining when to trigger long-term storage
|
scope: Window size determining when to trigger long-term storage
|
||||||
"""
|
"""
|
||||||
scope = scope
|
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
if is_end_user_has_history:
|
||||||
if is_end_user_id is not False:
|
end_user_visit_count, redis_messages = is_end_user_has_history
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
else:
|
||||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
return
|
||||||
is_end_user_id += 1
|
end_user_visit_count += 1
|
||||||
langchain_messages += redis_messages
|
if end_user_visit_count < scope:
|
||||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
redis_messages.extend(langchain_messages)
|
||||||
elif int(is_end_user_id) == int(scope):
|
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
|
||||||
|
else:
|
||||||
logger.info('写入长期记忆NEO4J')
|
logger.info('写入长期记忆NEO4J')
|
||||||
formatted_messages = redis_messages
|
redis_messages.extend(langchain_messages)
|
||||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||||
if hasattr(memory_config, 'config_id'):
|
if hasattr(memory_config, 'config_id'):
|
||||||
config_id = memory_config.config_id
|
config_id = memory_config.config_id
|
||||||
else:
|
else:
|
||||||
config_id = memory_config
|
config_id = memory_config
|
||||||
|
|
||||||
await write(
|
write_message_task.delay(
|
||||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
end_user_id, # end_user_id: User ID
|
||||||
end_user_id,
|
redis_messages, # message: JSON string format message list
|
||||||
"",
|
config_id, # config_id: Configuration ID string
|
||||||
"",
|
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||||
None,
|
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||||
end_user_id,
|
|
||||||
config_id,
|
|
||||||
formatted_messages
|
|
||||||
)
|
)
|
||||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
count_store.update_sessions_count(end_user_id, 0, [])
|
||||||
else:
|
|
||||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
|
||||||
|
|
||||||
|
|
||||||
"""Time-based memory processing"""
|
|
||||||
|
|
||||||
|
|
||||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||||
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"is_same_event": False,
|
"is_same_event": False,
|
||||||
|
|||||||
@@ -1,49 +1,25 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from langgraph.constants import END, START
|
|
||||||
from langgraph.graph import StateGraph
|
|
||||||
|
|
||||||
from app.db import get_db, get_db_context
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
aggregate_judgment
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
from app.db import get_db_context
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
from app.services.memory_konwledges_server import write_rag
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
||||||
|
|
||||||
|
async def long_term_storage(
|
||||||
@asynccontextmanager
|
long_term_type: str,
|
||||||
async def make_write_graph():
|
langchain_messages: list,
|
||||||
"""
|
memory_config_id: str,
|
||||||
Create a write graph workflow for memory operations.
|
end_user_id: str,
|
||||||
|
scope: int = 6
|
||||||
Args:
|
):
|
||||||
user_id: User identifier
|
|
||||||
tools: MCP tools loaded from session
|
|
||||||
apply_id: Application identifier
|
|
||||||
end_user_id: Group identifier
|
|
||||||
memory_config: MemoryConfig object containing all configuration
|
|
||||||
"""
|
|
||||||
workflow = StateGraph(WriteState)
|
|
||||||
workflow.add_node("save_neo4j", write_node)
|
|
||||||
workflow.add_edge(START, "save_neo4j")
|
|
||||||
workflow.add_edge("save_neo4j", END)
|
|
||||||
|
|
||||||
graph = workflow.compile()
|
|
||||||
|
|
||||||
yield graph
|
|
||||||
|
|
||||||
|
|
||||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
|
||||||
end_user_id: str = '', scope: int = 6):
|
|
||||||
"""
|
"""
|
||||||
Handle long-term memory storage with different strategies
|
Handle long-term memory storage with different strategies
|
||||||
|
|
||||||
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
|
|||||||
Args:
|
Args:
|
||||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||||
langchain_messages: List of messages to store
|
langchain_messages: List of messages to store
|
||||||
memory_config: Memory configuration identifier
|
memory_config_id: Memory configuration identifier
|
||||||
end_user_id: User group identifier
|
end_user_id: User group identifier
|
||||||
scope: Scope parameter for chunk-based storage (default: 6)
|
scope: Scope parameter for chunk-based storage (default: 6)
|
||||||
"""
|
"""
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
if langchain_messages is None:
|
||||||
aggregate_judgment
|
langchain_messages = []
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
write_store.save_session_write(end_user_id, langchain_messages)
|
write_store.save_session_write(end_user_id, langchain_messages)
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
config_service = MemoryConfigService(db_session)
|
config_service = MemoryConfigService(db_session)
|
||||||
memory_config = config_service.load_memory_config(
|
memory_config = config_service.load_memory_config(
|
||||||
config_id=memory_config, # 改为整数
|
config_id=memory_config_id, # 改为整数
|
||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
# Dialogue window with 6 rounds of conversation
|
||||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||||
"""Time-based strategy"""
|
# Time-based strategy
|
||||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||||
"""Strategy 3: Aggregate judgment"""
|
# Aggregate judgment
|
||||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
|
|
||||||
|
|
||||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
async def write_long_term(
|
||||||
|
storage_type: str,
|
||||||
|
end_user_id: str,
|
||||||
|
messages: list[dict],
|
||||||
|
user_rag_memory_id: str,
|
||||||
|
actual_config_id: str
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Write long-term memory with different storage types
|
Write long-term memory with different storage types
|
||||||
|
|
||||||
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
|
|||||||
Args:
|
Args:
|
||||||
storage_type: Type of storage (RAG or traditional)
|
storage_type: Type of storage (RAG or traditional)
|
||||||
end_user_id: User group identifier
|
end_user_id: User group identifier
|
||||||
message_chat: User message content
|
messages: message list
|
||||||
aimessages: AI response messages
|
|
||||||
user_rag_memory_id: RAG memory identifier
|
user_rag_memory_id: RAG memory identifier
|
||||||
actual_config_id: Actual configuration ID
|
actual_config_id: Actual configuration ID
|
||||||
"""
|
"""
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
|
||||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
message_content = []
|
||||||
|
for message in messages:
|
||||||
|
message_content.append(f'{message.get("role")}:{message.get("content")}')
|
||||||
|
messages_string = "\n".join(message_content)
|
||||||
|
await write_rag(end_user_id, messages_string, user_rag_memory_id)
|
||||||
else:
|
else:
|
||||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
await long_term_storage(long_term_type=CHUNK,
|
||||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
langchain_messages=messages,
|
||||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
memory_config_id=actual_config_id,
|
||||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
end_user_id=end_user_id,
|
||||||
|
scope=SCOPE)
|
||||||
# async def main():
|
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
|
||||||
# """主函数 - 运行工作流"""
|
|
||||||
# langchain_messages = [
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": "今天周五去爬山"
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "role": "assistant",
|
|
||||||
# "content": "好耶"
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# ]
|
|
||||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
|
||||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
|
||||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# import asyncio
|
|
||||||
# asyncio.run(main())
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import uuid
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from typing import List, Dict, Any, Optional, Union
|
from typing import List, Dict, Any, Optional, Union
|
||||||
|
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
from app.core.memory.agent.utils.redis_base import (
|
from app.core.memory.agent.utils.redis_base import (
|
||||||
serialize_messages,
|
serialize_messages,
|
||||||
deserialize_messages,
|
deserialize_messages,
|
||||||
@@ -14,7 +15,7 @@ from app.core.memory.agent.utils.redis_base import (
|
|||||||
get_current_timestamp
|
get_current_timestamp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RedisWriteStore:
|
class RedisWriteStore:
|
||||||
@@ -66,10 +67,10 @@ class RedisWriteStore:
|
|||||||
})
|
})
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[save_session_write] 保存会话失败: {e}")
|
logger.error(f"[save_session_write] 保存会话失败: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||||
@@ -112,10 +113,10 @@ class RedisWriteStore:
|
|||||||
if not results:
|
if not results:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
logger.error(f"[get_session_by_userid] 查询失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||||
@@ -144,7 +145,7 @@ class RedisWriteStore:
|
|||||||
# 只查询 write 类型的 key
|
# 只查询 write 类型的 key
|
||||||
keys = self.r.keys('session:write:*')
|
keys = self.r.keys('session:write:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -175,18 +176,16 @@ class RedisWriteStore:
|
|||||||
results.append(session_info)
|
results.append(session_info)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 按时间排序(最新的在前)
|
# 按时间排序(最新的在前)
|
||||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||||
|
|
||||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def find_user_recent_sessions(self, userid: str,
|
def find_user_recent_sessions(self, userid: str,
|
||||||
@@ -207,7 +206,7 @@ class RedisWriteStore:
|
|||||||
# 只查询 write 类型的 key
|
# 只查询 write 类型的 key
|
||||||
keys = self.r.keys('session:write:*')
|
keys = self.r.keys('session:write:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -234,11 +233,10 @@ class RedisWriteStore:
|
|||||||
# 根据时间范围过滤
|
# 根据时间范围过滤
|
||||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||||
# 排序并移除时间字段
|
# 排序并移除时间字段
|
||||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
result_items = sort_and_limit_results(filtered_items)
|
||||||
print(result_items)
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
@@ -278,7 +276,7 @@ class RedisCountStore:
|
|||||||
decode_responses=True,
|
decode_responses=True,
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
self.uudi = session_id
|
self.uuid = session_id
|
||||||
|
|
||||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -298,7 +296,7 @@ class RedisCountStore:
|
|||||||
|
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.hset(key, mapping={
|
pipe.hset(key, mapping={
|
||||||
"id": self.uudi,
|
"id": self.uuid,
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"count": int(count),
|
"count": int(count),
|
||||||
"messages": serialize_messages(messages),
|
"messages": serialize_messages(messages),
|
||||||
@@ -311,10 +309,10 @@ class RedisCountStore:
|
|||||||
|
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
|
||||||
"""
|
"""
|
||||||
通过 end_user_id 查询访问次数统计
|
通过 end_user_id 查询访问次数统计
|
||||||
|
|
||||||
@@ -335,7 +333,7 @@ class RedisCountStore:
|
|||||||
self.r.delete(index_key)
|
self.r.delete(index_key)
|
||||||
return False
|
return False
|
||||||
except Exception as type_error:
|
except Exception as type_error:
|
||||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||||
|
|
||||||
session_id = self.r.get(index_key)
|
session_id = self.r.get(index_key)
|
||||||
|
|
||||||
@@ -355,15 +353,20 @@ class RedisCountStore:
|
|||||||
messages_str = data.get('messages')
|
messages_str = data.get('messages')
|
||||||
|
|
||||||
if count is not None:
|
if count is not None:
|
||||||
messages = deserialize_messages(messages_str)
|
messages: list[dict] = deserialize_messages(messages_str)
|
||||||
return [int(count), messages]
|
return int(count), messages
|
||||||
|
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[get_sessions_count] 查询失败: {e}")
|
logger.error(f"[get_sessions_count] 查询失败: {e}")
|
||||||
return False
|
return False
|
||||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
|
||||||
messages: Any) -> bool:
|
def update_sessions_count(
|
||||||
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
new_count: int,
|
||||||
|
messages: Any
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||||
|
|
||||||
@@ -384,17 +387,17 @@ class RedisCountStore:
|
|||||||
key_type = self.r.type(index_key)
|
key_type = self.r.type(index_key)
|
||||||
if key_type != 'string' and key_type != 'none':
|
if key_type != 'string' and key_type != 'none':
|
||||||
# 索引键类型错误,删除并返回 False
|
# 索引键类型错误,删除并返回 False
|
||||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||||
self.r.delete(index_key)
|
self.r.delete(index_key)
|
||||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||||
return False
|
return False
|
||||||
except Exception as type_error:
|
except Exception as type_error:
|
||||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||||
|
|
||||||
session_id = self.r.get(index_key)
|
session_id = self.r.get(index_key)
|
||||||
|
|
||||||
if not session_id:
|
if not session_id:
|
||||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 直接更新数据
|
# 直接更新数据
|
||||||
@@ -402,15 +405,15 @@ class RedisCountStore:
|
|||||||
messages_str = serialize_messages(messages)
|
messages_str = serialize_messages(messages)
|
||||||
|
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.hset(key, 'count', int(new_count))
|
pipe.hset(key, 'count', str(new_count))
|
||||||
pipe.hset(key, 'messages', messages_str)
|
pipe.hset(key, 'messages', messages_str)
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[update_sessions_count] 更新失败: {e}")
|
logger.debug(f"[update_sessions_count] 更新失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_all_count_sessions(self) -> int:
|
def delete_all_count_sessions(self) -> int:
|
||||||
@@ -453,7 +456,7 @@ class RedisSessionStore:
|
|||||||
# ==================== 写入操作 ====================
|
# ==================== 写入操作 ====================
|
||||||
|
|
||||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||||
apply_id: str, end_user_id: str) -> str:
|
apply_id: str, end_user_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
写入一条会话数据,返回 session_id
|
写入一条会话数据,返回 session_id
|
||||||
|
|
||||||
@@ -483,10 +486,10 @@ class RedisSessionStore:
|
|||||||
})
|
})
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[save_session] 保存会话失败: {e}")
|
logger.error(f"[save_session] 保存会话失败: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# ==================== 读取操作 ====================
|
# ==================== 读取操作 ====================
|
||||||
@@ -521,7 +524,7 @@ class RedisSessionStore:
|
|||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||||
end_user_id: str) -> List[Dict[str, str]]:
|
end_user_id: str) -> List[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||||
|
|
||||||
@@ -538,7 +541,7 @@ class RedisSessionStore:
|
|||||||
|
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -556,7 +559,7 @@ class RedisSessionStore:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if (data.get('apply_id') == apply_id and
|
if (data.get('apply_id') == apply_id and
|
||||||
data.get('end_user_id') == end_user_id):
|
data.get('end_user_id') == end_user_id):
|
||||||
# 支持模糊匹配或完全匹配 sessionid
|
# 支持模糊匹配或完全匹配 sessionid
|
||||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||||
matched_items.append(format_session_data(data, include_time=True))
|
matched_items.append(format_session_data(data, include_time=True))
|
||||||
@@ -565,7 +568,7 @@ class RedisSessionStore:
|
|||||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
|
|
||||||
@@ -632,7 +635,7 @@ class RedisSessionStore:
|
|||||||
|
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print("[delete_duplicate_sessions] 没有会话数据")
|
logger.debug("[delete_duplicate_sessions] 没有会话数据")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 批量获取所有数据
|
# 批量获取所有数据
|
||||||
@@ -678,7 +681,7 @@ class RedisSessionStore:
|
|||||||
deleted_count += len(batch)
|
deleted_count += len(batch)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class LLMClient(ABC):
|
|||||||
self.max_retries = self.config.max_retries
|
self.max_retries = self.config.max_retries
|
||||||
self.timeout = self.config.timeout
|
self.timeout = self.config.timeout
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"初始化 LLM 客户端: provider={self.provider}, "
|
f"初始化 LLM 客户端: provider={self.provider}, "
|
||||||
f"model={self.model_name}, max_retries={self.max_retries}"
|
f"model={self.model_name}, max_retries={self.max_retries}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class Write_UserInput(BaseModel):
|
|||||||
end_user_id: str
|
end_user_id: str
|
||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class AgentMemory_Long_Term(ABC):
|
class AgentMemory_Long_Term(ABC):
|
||||||
"""长期记忆配置常量"""
|
"""长期记忆配置常量"""
|
||||||
STORAGE_NEO4J = "neo4j"
|
STORAGE_NEO4J = "neo4j"
|
||||||
@@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC):
|
|||||||
STRATEGY_CHUNK = "chunk"
|
STRATEGY_CHUNK = "chunk"
|
||||||
STRATEGY_TIME = "time"
|
STRATEGY_TIME = "time"
|
||||||
DEFAULT_SCOPE = 6
|
DEFAULT_SCOPE = 6
|
||||||
TIME_SCOPE=5
|
TIME_SCOPE = 5
|
||||||
class AgentMemoryDataset(ABC):
|
|
||||||
PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余']
|
|
||||||
NAME='用户'
|
|
||||||
|
|
||||||
|
|
||||||
|
class AgentMemoryDataset(ABC):
|
||||||
|
PRONOUN = ['我', '本人', '在下', '自己', '咱', '鄙人', '吴', '余']
|
||||||
|
NAME = '用户'
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||||
from app.models import WorkflowConfig
|
from app.models import WorkflowConfig
|
||||||
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
|
|||||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
from app.schemas import FileType
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -43,18 +44,17 @@ class AppChatService:
|
|||||||
message: str,
|
message: str,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
config: AgentConfig,
|
config: AgentConfig,
|
||||||
user_id: Optional[str] = None,
|
files: list[FileInput],
|
||||||
|
user_id: str,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
web_search: bool = False,
|
web_search: bool = False,
|
||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None
|
||||||
files: Optional[List[FileInput]] = None
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = None
|
|
||||||
|
|
||||||
# 应用 features 配置
|
# 应用 features 配置
|
||||||
features_config: dict = config.features or {}
|
features_config: dict = config.features or {}
|
||||||
@@ -93,7 +93,8 @@ class AppChatService:
|
|||||||
tools.extend(skill_tools)
|
tools.extend(skill_tools)
|
||||||
if skill_prompts:
|
if skill_prompts:
|
||||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
|
||||||
|
user_id)
|
||||||
tools.extend(kb_tools)
|
tools.extend(kb_tools)
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
if memory:
|
if memory:
|
||||||
@@ -168,11 +169,6 @@ class AppChatService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=None,
|
context=None,
|
||||||
end_user_id=user_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
config_id=config_id,
|
|
||||||
memory_flag=memory_flag,
|
|
||||||
files=processed_files # 传递处理后的文件
|
files=processed_files # 传递处理后的文件
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -229,6 +225,21 @@ class AppChatService:
|
|||||||
# 保存消息
|
# 保存消息
|
||||||
if audio_url:
|
if audio_url:
|
||||||
assistant_meta["audio_url"] = audio_url
|
assistant_meta["audio_url"] = audio_url
|
||||||
|
if memory_flag:
|
||||||
|
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||||
|
memory_config_id: str = connected_config.get("memory_config_id")
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||||
|
{"role": "assistant", "content": result["content"]}
|
||||||
|
]
|
||||||
|
if memory_config_id:
|
||||||
|
await write_long_term(
|
||||||
|
storage_type,
|
||||||
|
user_id,
|
||||||
|
messages,
|
||||||
|
user_rag_memory_id,
|
||||||
|
memory_config_id
|
||||||
|
)
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="user",
|
role="user",
|
||||||
@@ -264,20 +275,19 @@ class AppChatService:
|
|||||||
message: str,
|
message: str,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
config: AgentConfig,
|
config: AgentConfig,
|
||||||
|
files: list[FileInput],
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
web_search: bool = False,
|
web_search: bool = False,
|
||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None
|
||||||
files: Optional[List[FileInput]] = None
|
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""聊天(流式)"""
|
"""聊天(流式)"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = None
|
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
|
|
||||||
# 应用 features 配置
|
# 应用 features 配置
|
||||||
@@ -319,7 +329,8 @@ class AppChatService:
|
|||||||
tools.extend(skill_tools)
|
tools.extend(skill_tools)
|
||||||
if skill_prompts:
|
if skill_prompts:
|
||||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
|
||||||
|
config.knowledge_retrieval, user_id)
|
||||||
tools.extend(kb_tools)
|
tools.extend(kb_tools)
|
||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
@@ -411,11 +422,6 @@ class AppChatService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=None,
|
context=None,
|
||||||
end_user_id=user_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
config_id=config_id,
|
|
||||||
memory_flag=memory_flag,
|
|
||||||
files=processed_files
|
files=processed_files
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
@@ -459,7 +465,7 @@ class AppChatService:
|
|||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
human_meta = {
|
human_meta = {
|
||||||
"files":[],
|
"files": [],
|
||||||
"history_files": {}
|
"history_files": {}
|
||||||
}
|
}
|
||||||
assistant_meta = {
|
assistant_meta = {
|
||||||
@@ -484,6 +490,22 @@ class AppChatService:
|
|||||||
|
|
||||||
if stream_audio_url:
|
if stream_audio_url:
|
||||||
assistant_meta["audio_url"] = stream_audio_url
|
assistant_meta["audio_url"] = stream_audio_url
|
||||||
|
|
||||||
|
if memory_flag:
|
||||||
|
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||||
|
memory_config_id: str = connected_config.get("memory_config_id")
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||||
|
{"role": "assistant", "content": full_content}
|
||||||
|
]
|
||||||
|
if memory_config_id:
|
||||||
|
await write_long_term(
|
||||||
|
storage_type,
|
||||||
|
user_id,
|
||||||
|
messages,
|
||||||
|
user_rag_memory_id,
|
||||||
|
memory_config_id
|
||||||
|
)
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="user",
|
role="user",
|
||||||
@@ -618,7 +640,6 @@ class AppChatService:
|
|||||||
# 2. 创建编排器
|
# 2. 创建编排器
|
||||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||||
|
|
||||||
|
|
||||||
# 3. 流式执行任务
|
# 3. 流式执行任务
|
||||||
async for event in orchestrator.execute_stream(
|
async for event in orchestrator.execute_stream(
|
||||||
message=message,
|
message=message,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import AgentConfig, ModelConfig, ModelType
|
from app.models import AgentConfig, ModelConfig
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.app_schema import FileInput, Citation
|
from app.schemas.app_schema import FileInput, Citation
|
||||||
from app.schemas.model_schema import ModelInfo
|
from app.schemas.model_schema import ModelInfo
|
||||||
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
|
|||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
from app.schemas import FileType
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -657,11 +656,6 @@ class AgentRunService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=context,
|
context=context,
|
||||||
end_user_id=user_id,
|
|
||||||
config_id=config_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
memory_flag=memory_flag,
|
|
||||||
files=processed_files # 传递处理后的文件
|
files=processed_files # 传递处理后的文件
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -911,11 +905,6 @@ class AgentRunService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=context,
|
context=context,
|
||||||
end_user_id=user_id,
|
|
||||||
config_id=config_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
memory_flag=memory_flag,
|
|
||||||
files=processed_files
|
files=processed_files
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
|
|||||||
@@ -243,27 +243,6 @@ class MemoryPerceptualService:
|
|||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
file: FileInput
|
file: FileInput
|
||||||
):
|
):
|
||||||
memories = self.repository.get_by_url(file.url)
|
|
||||||
if memories:
|
|
||||||
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
|
||||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
|
||||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
|
||||||
memory_cache = memories[0]
|
|
||||||
memory = self.repository.create_perceptual_memory(
|
|
||||||
end_user_id=uuid.UUID(end_user_id),
|
|
||||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
|
||||||
file_path=memory_cache.file_path,
|
|
||||||
file_name=memory_cache.file_name,
|
|
||||||
file_ext=memory_cache.file_ext,
|
|
||||||
summary=memory_cache.summary,
|
|
||||||
meta_data=memory_cache.meta_data
|
|
||||||
)
|
|
||||||
self.db.commit()
|
|
||||||
return memory
|
|
||||||
else:
|
|
||||||
for memory in memories:
|
|
||||||
if memory.end_user_id == uuid.UUID(end_user_id):
|
|
||||||
return memory
|
|
||||||
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||||
multimodel_service = MultimodalService(self.db, ModelInfo(
|
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||||
model_name=model_config.model_name,
|
model_name=model_config.model_name,
|
||||||
|
|||||||
@@ -69,7 +69,8 @@ class ModelConfigService:
|
|||||||
return items
|
return items
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
def get_model_by_name(db: Session, name: str, provider: str | None = None,
|
||||||
|
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||||
"""根据名称获取模型配置"""
|
"""根据名称获取模型配置"""
|
||||||
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
||||||
if not model:
|
if not model:
|
||||||
@@ -77,21 +78,22 @@ class ModelConfigService:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
|
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
|
||||||
|
ModelConfig]:
|
||||||
"""按名称模糊匹配获取模型配置列表"""
|
"""按名称模糊匹配获取模型配置列表"""
|
||||||
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
|
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def validate_model_config(
|
async def validate_model_config(
|
||||||
db: Session,
|
db: Session,
|
||||||
*,
|
*,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
model_type: str = "llm",
|
model_type: str = "llm",
|
||||||
test_message: str = "Hello",
|
test_message: str = "Hello",
|
||||||
is_omni: bool = False
|
is_omni: bool = False
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""验证模型配置是否有效
|
"""验证模型配置是否有效
|
||||||
|
|
||||||
@@ -265,7 +267,6 @@ class ModelConfigService:
|
|||||||
# 提取详细的错误信息
|
# 提取详细的错误信息
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
error_type = type(e).__name__
|
error_type = type(e).__name__
|
||||||
print("=========error_message:",error_message.lower())
|
|
||||||
# 特殊处理常见的错误类型
|
# 特殊处理常见的错误类型
|
||||||
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
|
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
|
||||||
# 区域/国家限制(适用于所有提供商)
|
# 区域/国家限制(适用于所有提供商)
|
||||||
@@ -354,14 +355,16 @@ class ModelConfigService:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
|
||||||
|
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||||
"""更新模型配置"""
|
"""更新模型配置"""
|
||||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||||
if not existing_model:
|
if not existing_model:
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
if model_data.name and model_data.name != existing_model.name:
|
if model_data.name and model_data.name != existing_model.name:
|
||||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||||
|
tenant_id=tenant_id):
|
||||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||||
@@ -370,9 +373,11 @@ class ModelConfigService:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
|
||||||
|
tenant_id: uuid.UUID) -> ModelConfig:
|
||||||
"""创建组合模型"""
|
"""创建组合模型"""
|
||||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
|
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
|
||||||
|
tenant_id=tenant_id):
|
||||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
# 验证所有 API Key 存在且类型匹配
|
# 验证所有 API Key 存在且类型匹配
|
||||||
@@ -430,14 +435,16 @@ class ModelConfigService:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
|
||||||
|
tenant_id: uuid.UUID) -> ModelConfig:
|
||||||
"""更新组合模型"""
|
"""更新组合模型"""
|
||||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||||
if not existing_model:
|
if not existing_model:
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
if model_data.name and model_data.name != existing_model.name:
|
if model_data.name and model_data.name != existing_model.name:
|
||||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||||
|
tenant_id=tenant_id):
|
||||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
if not existing_model.is_composite:
|
if not existing_model.is_composite:
|
||||||
@@ -760,7 +767,6 @@ class ModelApiKeyService:
|
|||||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelBaseService:
|
class ModelBaseService:
|
||||||
"""基础模型服务"""
|
"""基础模型服务"""
|
||||||
|
|
||||||
|
|||||||
@@ -1,26 +1,24 @@
|
|||||||
"""基于分享链接的聊天服务"""
|
"""基于分享链接的聊天服务"""
|
||||||
import uuid
|
|
||||||
import time
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
from typing import Optional, Dict, Any, AsyncGenerator
|
from typing import Optional, Dict, Any, AsyncGenerator
|
||||||
|
|
||||||
|
from deprecated import deprecated
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.repositories.model_repository import ModelApiKeyRepository
|
from app.core.error_codes import BizCode
|
||||||
from app.services.memory_konwledges_server import write_rag
|
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.models import MultiAgentConfig
|
||||||
from app.models import ReleaseShare, AppRelease, Conversation
|
from app.models import ReleaseShare, AppRelease, Conversation
|
||||||
|
from app.repositories import knowledge_repository
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.draft_run_service import create_web_search_tool
|
from app.services.draft_run_service import create_web_search_tool
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.release_share_service import ReleaseShareService
|
|
||||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.services.multi_agent_service import MultiAgentService
|
from app.services.multi_agent_service import MultiAgentService
|
||||||
from app.models import MultiAgentConfig
|
from app.services.release_share_service import ReleaseShareService
|
||||||
from app.repositories import knowledge_repository
|
|
||||||
import json
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
|
||||||
from app.tasks import write_message_task
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -118,6 +116,7 @@ class SharedChatService:
|
|||||||
|
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
|
@deprecated("Use the chat method under app_chat_service instead.")
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
share_token: str,
|
share_token: str,
|
||||||
@@ -136,10 +135,7 @@ class SharedChatService:
|
|||||||
config_id = actual_config_id
|
config_id = actual_config_id
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||||
from app.services.model_parameter_merger import ModelParameterMerger
|
|
||||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
from sqlalchemy import select
|
|
||||||
from app.models import ModelApiKey
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
actual_config_id = None
|
actual_config_id = None
|
||||||
@@ -273,11 +269,6 @@ class SharedChatService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=None,
|
context=None,
|
||||||
end_user_id=user_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
config_id=config_id,
|
|
||||||
memory_flag=memory_flag
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
@@ -324,6 +315,7 @@ class SharedChatService:
|
|||||||
"elapsed_time": elapsed_time
|
"elapsed_time": elapsed_time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@deprecated("Use the chat method under app_chat_service instead.")
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
self,
|
self,
|
||||||
share_token: str,
|
share_token: str,
|
||||||
@@ -341,8 +333,6 @@ class SharedChatService:
|
|||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
from sqlalchemy import select
|
|
||||||
from app.models import ModelApiKey
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -486,11 +476,6 @@ class SharedChatService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=None,
|
context=None,
|
||||||
end_user_id=user_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
config_id=config_id,
|
|
||||||
memory_flag=memory_flag
|
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
total_tokens = chunk
|
total_tokens = chunk
|
||||||
@@ -585,6 +570,7 @@ class SharedChatService:
|
|||||||
|
|
||||||
return conversations, total
|
return conversations, total
|
||||||
|
|
||||||
|
@deprecated("Use the chat method under app_chat_service instead.")
|
||||||
async def multi_agent_chat(
|
async def multi_agent_chat(
|
||||||
self,
|
self,
|
||||||
share_token: str,
|
share_token: str,
|
||||||
@@ -680,6 +666,7 @@ class SharedChatService:
|
|||||||
"elapsed_time": elapsed_time
|
"elapsed_time": elapsed_time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@deprecated("Use the chat method under app_chat_service instead.")
|
||||||
async def multi_agent_chat_stream(
|
async def multi_agent_chat_stream(
|
||||||
self,
|
self,
|
||||||
share_token: str,
|
share_token: str,
|
||||||
|
|||||||
Reference in New Issue
Block a user