feat(agent, memory): add agent-perceived memory writing
This commit is contained in:
@@ -410,30 +410,6 @@ async def chat(
|
||||
agent_config = agent_config_4_app_release(release)
|
||||
|
||||
if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
@@ -459,20 +435,6 @@ async def chat(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
# 非流式返回
|
||||
# result = await service.chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
result = await app_chat_service.agnet_chat(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
@@ -531,48 +493,6 @@ async def chat(
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
# 多 Agent 流式返回
|
||||
# if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.multi_agent_chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
|
||||
# # 多 Agent 非流式返回
|
||||
# result = await service.multi_agent_chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
config = workflow_config_4_app_release(release)
|
||||
if not config.id:
|
||||
|
||||
@@ -11,18 +11,14 @@ LangChain Agent 封装
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType, ModelProvider
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -226,10 +222,9 @@ class LangChainAgent:
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
"""
|
||||
messages = []
|
||||
messages:list = [SystemMessage(content=self.system_prompt)]
|
||||
|
||||
# 添加系统提示词
|
||||
messages.append(SystemMessage(content=self.system_prompt))
|
||||
|
||||
# 添加历史消息
|
||||
if history:
|
||||
@@ -293,12 +288,7 @@ class LangChainAgent:
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -306,32 +296,12 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||
context: 上下文信息(如知识库检索结果)
|
||||
files: 多模态文件
|
||||
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat = message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -419,9 +389,6 @@ class LangChainAgent:
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -452,12 +419,7 @@ class LangChainAgent:
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -465,6 +427,7 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件
|
||||
|
||||
Yields:
|
||||
str: 消息内容块
|
||||
@@ -475,23 +438,6 @@ class LangChainAgent:
|
||||
logger.info(f" Has tools: {bool(self.tools)}")
|
||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||
logger.info("=" * 80)
|
||||
message_chat = message
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -501,17 +447,18 @@ class LangChainAgent:
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
yielded_content = False
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content = ''
|
||||
try:
|
||||
last_event = {}
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
last_event = event
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
@@ -525,7 +472,6 @@ class LangChainAgent:
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
@@ -536,18 +482,15 @@ class LangChainAgent:
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
@@ -558,7 +501,6 @@ class LangChainAgent:
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
@@ -569,22 +511,18 @@ class LangChainAgent:
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
@@ -594,7 +532,7 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
@@ -604,9 +542,7 @@ class LangChainAgent:
|
||||
) if response_meta else 0
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
"""
|
||||
Write messages to RAG storage system
|
||||
|
||||
Combines user and AI messages into a single string format and stores them
|
||||
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for the conversation
|
||||
user_message: User's input message content
|
||||
ai_message: AI's response message content
|
||||
user_rag_memory_id: RAG memory identifier for storage location
|
||||
"""
|
||||
# RAG mode: combine messages into string format (maintain original logic)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
@@ -118,7 +98,7 @@ async def write(
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
|
||||
|
||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
||||
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
to long-term memory storage.
|
||||
|
||||
Args:
|
||||
long_term_messages: Long-term message data to be saved
|
||||
actual_config_id: Configuration identifier for memory settings
|
||||
end_user_id: User identifier for memory association
|
||||
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
scope: Scope/window size for memory processing
|
||||
"""
|
||||
with get_db_context() as db_session:
|
||||
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
if not result:
|
||||
logger.warning(f"No write data found for user {end_user_id}")
|
||||
return
|
||||
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data) == scope:
|
||||
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
"""Window-based dialogue processing"""
|
||||
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
scope = scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_has_history:
|
||||
end_user_visit_count, redis_messages = is_end_user_has_history
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
return
|
||||
end_user_visit_count += 1
|
||||
if end_user_visit_count < scope:
|
||||
redis_messages.extend(langchain_messages)
|
||||
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
|
||||
else:
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = redis_messages
|
||||
redis_messages.extend(langchain_messages)
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||
end_user_id,
|
||||
"",
|
||||
"",
|
||||
None,
|
||||
end_user_id,
|
||||
config_id,
|
||||
formatted_messages
|
||||
write_message_task.delay(
|
||||
end_user_id, # end_user_id: User ID
|
||||
redis_messages, # message: JSON string format message list
|
||||
config_id, # config_id: Configuration ID string
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""Time-based memory processing"""
|
||||
count_store.update_sessions_count(end_user_id, 0, [])
|
||||
|
||||
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
|
||||
@@ -1,49 +1,25 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||
end_user_id: str = '', scope: int = 6):
|
||||
async def long_term_storage(
|
||||
long_term_type: str,
|
||||
langchain_messages: list,
|
||||
memory_config_id: str,
|
||||
end_user_id: str,
|
||||
scope: int = 6
|
||||
):
|
||||
"""
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
|
||||
Args:
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
memory_config: Memory configuration identifier
|
||||
memory_config_id: Memory configuration identifier
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
if langchain_messages is None:
|
||||
langchain_messages = []
|
||||
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
config_id=memory_config_id, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
||||
# Dialogue window with 6 rounds of conversation
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
"""Time-based strategy"""
|
||||
# Time-based strategy
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
"""Strategy 3: Aggregate judgment"""
|
||||
# Aggregate judgment
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
||||
async def write_long_term(
|
||||
storage_type: str,
|
||||
end_user_id: str,
|
||||
messages: list[dict],
|
||||
user_rag_memory_id: str,
|
||||
actual_config_id: str
|
||||
):
|
||||
"""
|
||||
Write long-term memory with different storage types
|
||||
|
||||
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
|
||||
Args:
|
||||
storage_type: Type of storage (RAG or traditional)
|
||||
end_user_id: User group identifier
|
||||
message_chat: User message content
|
||||
aimessages: AI response messages
|
||||
messages: message list
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_config_id: Actual configuration ID
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
message_content = []
|
||||
for message in messages:
|
||||
message_content.append(f'{message.get("role")}:{message.get("content")}')
|
||||
messages_string = "\n".join(message_content)
|
||||
await write_rag(end_user_id, messages_string, user_rag_memory_id)
|
||||
else:
|
||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
await long_term_storage(long_term_type=CHUNK,
|
||||
langchain_messages=messages,
|
||||
memory_config_id=actual_config_id,
|
||||
end_user_id=end_user_id,
|
||||
scope=SCOPE)
|
||||
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
@@ -3,8 +3,9 @@ import uuid
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
@@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import (
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -66,10 +67,10 @@ class RedisWriteStore:
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session_write] 保存会话失败: {e}")
|
||||
logger.error(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
@@ -99,7 +100,7 @@ class RedisWriteStore:
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
@@ -108,16 +109,16 @@ class RedisWriteStore:
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
|
||||
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||
logger.error(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
@@ -144,7 +145,7 @@ class RedisWriteStore:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
@@ -158,12 +159,12 @@ class RedisWriteStore:
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
@@ -173,23 +174,21 @@ class RedisWriteStore:
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
|
||||
if not results:
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
@@ -203,11 +202,11 @@ class RedisWriteStore:
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
@@ -221,7 +220,7 @@ class RedisWriteStore:
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
@@ -230,15 +229,14 @@ class RedisWriteStore:
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||
print(result_items)
|
||||
result_items = sort_and_limit_results(filtered_items)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
@@ -258,7 +256,7 @@ class RedisWriteStore:
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -278,7 +276,7 @@ class RedisCountStore:
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
self.uuid = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
@@ -295,26 +293,26 @@ class RedisCountStore:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"id": self.uuid,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
|
||||
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
@@ -327,7 +325,7 @@ class RedisCountStore:
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
@@ -335,35 +333,40 @@ class RedisCountStore:
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
|
||||
if count is not None:
|
||||
messages = deserialize_messages(messages_str)
|
||||
return [int(count), messages]
|
||||
|
||||
messages: list[dict] = deserialize_messages(messages_str)
|
||||
return int(count), messages
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[get_sessions_count] 查询失败: {e}")
|
||||
logger.error(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||
messages: Any) -> bool:
|
||||
|
||||
def update_sessions_count(
|
||||
self,
|
||||
end_user_id: str,
|
||||
new_count: int,
|
||||
messages: Any
|
||||
) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
@@ -378,39 +381,39 @@ class RedisCountStore:
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
|
||||
if not session_id:
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', int(new_count))
|
||||
pipe.hset(key, 'count', str(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
|
||||
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[update_sessions_count] 更新失败: {e}")
|
||||
logger.debug(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
@@ -428,7 +431,7 @@ class RedisCountStore:
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -451,9 +454,9 @@ class RedisSessionStore:
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
@@ -483,14 +486,14 @@ class RedisSessionStore:
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session] 保存会话失败: {e}")
|
||||
logger.error(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
@@ -520,8 +523,8 @@ class RedisSessionStore:
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
@@ -535,10 +538,10 @@ class RedisSessionStore:
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
@@ -556,21 +559,21 @@ class RedisSessionStore:
|
||||
continue
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
@@ -591,7 +594,7 @@ class RedisSessionStore:
|
||||
return bool(results[0])
|
||||
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
@@ -632,7 +635,7 @@ class RedisSessionStore:
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
logger.debug("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 批量获取所有数据
|
||||
@@ -678,7 +681,7 @@ class RedisSessionStore:
|
||||
deleted_count += len(batch)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class LLMClient(ABC):
|
||||
self.max_retries = self.config.max_retries
|
||||
self.timeout = self.config.timeout
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"初始化 LLM 客户端: provider={self.provider}, "
|
||||
f"model={self.model_name}, max_retries={self.max_retries}"
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ class Write_UserInput(BaseModel):
|
||||
end_user_id: str
|
||||
config_id: Optional[str] = None
|
||||
|
||||
|
||||
class AgentMemory_Long_Term(ABC):
|
||||
"""长期记忆配置常量"""
|
||||
STORAGE_NEO4J = "neo4j"
|
||||
@@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC):
|
||||
STRATEGY_CHUNK = "chunk"
|
||||
STRATEGY_TIME = "time"
|
||||
DEFAULT_SCOPE = 6
|
||||
TIME_SCOPE=5
|
||||
class AgentMemoryDataset(ABC):
|
||||
PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余']
|
||||
NAME='用户'
|
||||
TIME_SCOPE = 5
|
||||
|
||||
|
||||
class AgentMemoryDataset(ABC):
|
||||
PRONOUN = ['我', '本人', '在下', '自己', '咱', '鄙人', '吴', '余']
|
||||
NAME = '用户'
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||
from app.models import WorkflowConfig
|
||||
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.schemas import FileType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -43,18 +44,17 @@ class AppChatService:
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
user_id: Optional[str] = None,
|
||||
files: list[FileInput],
|
||||
user_id: str,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None
|
||||
workspace_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
|
||||
# 应用 features 配置
|
||||
features_config: dict = config.features or {}
|
||||
@@ -93,7 +93,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
|
||||
user_id)
|
||||
tools.extend(kb_tools)
|
||||
memory_flag = False
|
||||
if memory:
|
||||
@@ -168,11 +169,6 @@ class AppChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
@@ -229,6 +225,21 @@ class AppChatService:
|
||||
# 保存消息
|
||||
if audio_url:
|
||||
assistant_meta["audio_url"] = audio_url
|
||||
if memory_flag:
|
||||
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||
memory_config_id: str = connected_config.get("memory_config_id")
|
||||
messages = [
|
||||
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||
{"role": "assistant", "content": result["content"]}
|
||||
]
|
||||
if memory_config_id:
|
||||
await write_long_term(
|
||||
storage_type,
|
||||
user_id,
|
||||
messages,
|
||||
user_rag_memory_id,
|
||||
memory_config_id
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -264,20 +275,19 @@ class AppChatService:
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
files: list[FileInput],
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None
|
||||
workspace_id: Optional[str] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
message_id = uuid.uuid4()
|
||||
|
||||
# 应用 features 配置
|
||||
@@ -319,7 +329,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
|
||||
config.knowledge_retrieval, user_id)
|
||||
tools.extend(kb_tools)
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
@@ -411,11 +422,6 @@ class AppChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
@@ -459,7 +465,7 @@ class AppChatService:
|
||||
|
||||
# 保存消息
|
||||
human_meta = {
|
||||
"files":[],
|
||||
"files": [],
|
||||
"history_files": {}
|
||||
}
|
||||
assistant_meta = {
|
||||
@@ -484,6 +490,22 @@ class AppChatService:
|
||||
|
||||
if stream_audio_url:
|
||||
assistant_meta["audio_url"] = stream_audio_url
|
||||
|
||||
if memory_flag:
|
||||
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||
memory_config_id: str = connected_config.get("memory_config_id")
|
||||
messages = [
|
||||
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||
{"role": "assistant", "content": full_content}
|
||||
]
|
||||
if memory_config_id:
|
||||
await write_long_term(
|
||||
storage_type,
|
||||
user_id,
|
||||
messages,
|
||||
user_rag_memory_id,
|
||||
memory_config_id
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -618,7 +640,6 @@ class AppChatService:
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
|
||||
# 3. 流式执行任务
|
||||
async for event in orchestrator.execute_stream(
|
||||
message=message,
|
||||
|
||||
@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig, ModelType
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput, Citation
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.schemas import FileType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -657,11 +656,6 @@ class AgentRunService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=context,
|
||||
end_user_id=user_id,
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
@@ -911,11 +905,6 @@ class AgentRunService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=context,
|
||||
end_user_id=user_id,
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
|
||||
@@ -243,27 +243,6 @@ class MemoryPerceptualService:
|
||||
memory_config: MemoryConfig,
|
||||
file: FileInput
|
||||
):
|
||||
memories = self.repository.get_by_url(file.url)
|
||||
if memories:
|
||||
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||
memory_cache = memories[0]
|
||||
memory = self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||
file_path=memory_cache.file_path,
|
||||
file_name=memory_cache.file_name,
|
||||
file_ext=memory_cache.file_ext,
|
||||
summary=memory_cache.summary,
|
||||
meta_data=memory_cache.meta_data
|
||||
)
|
||||
self.db.commit()
|
||||
return memory
|
||||
else:
|
||||
for memory in memories:
|
||||
if memory.end_user_id == uuid.UUID(end_user_id):
|
||||
return memory
|
||||
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||
model_name=model_config.model_name,
|
||||
|
||||
@@ -69,7 +69,8 @@ class ModelConfigService:
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None,
|
||||
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
||||
if not model:
|
||||
@@ -77,21 +78,22 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
|
||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
|
||||
ModelConfig]:
|
||||
"""按名称模糊匹配获取模型配置列表"""
|
||||
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
|
||||
|
||||
@staticmethod
|
||||
async def validate_model_config(
|
||||
db: Session,
|
||||
*,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
db: Session,
|
||||
*,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""验证模型配置是否有效
|
||||
|
||||
@@ -158,13 +160,13 @@ class ModelConfigService:
|
||||
# 统一使用 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||
embedding = RedBearEmbeddings(model_config)
|
||||
test_texts = [test_message, "测试文本"]
|
||||
|
||||
|
||||
# 火山引擎使用 embed_batch,其他使用 embed_documents
|
||||
if provider.lower() == "volcano":
|
||||
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
|
||||
else:
|
||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
@@ -200,11 +202,11 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "image":
|
||||
# 图片生成模型验证
|
||||
from app.core.models.generation import RedBearImageGenerator
|
||||
|
||||
|
||||
generator = RedBearImageGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda",
|
||||
@@ -212,7 +214,7 @@ class ModelConfigService:
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"成功生成图片,结果: {result}")
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "图片生成模型配置验证成功",
|
||||
@@ -224,21 +226,21 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "video":
|
||||
# 视频生成模型验证
|
||||
from app.core.models.generation import RedBearVideoGenerator
|
||||
|
||||
|
||||
generator = RedBearVideoGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda playing in bamboo forest",
|
||||
duration=5
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 视频生成是异步任务,返回任务ID
|
||||
task_id = result.get("task_id") if isinstance(result, dict) else None
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "视频生成模型配置验证成功",
|
||||
@@ -265,7 +267,6 @@ class ModelConfigService:
|
||||
# 提取详细的错误信息
|
||||
error_message = str(e)
|
||||
error_type = type(e).__name__
|
||||
print("=========error_message:",error_message.lower())
|
||||
# 特殊处理常见的错误类型
|
||||
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
|
||||
# 区域/国家限制(适用于所有提供商)
|
||||
@@ -354,14 +355,16 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
|
||||
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""更新模型配置"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||
@@ -370,25 +373,27 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
|
||||
tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建组合模型"""
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
# 检查 API Key 关联的模型配置类型
|
||||
for model_config in api_key.model_configs:
|
||||
# chat 和 llm 类型可以兼容
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = model_data.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
@@ -399,7 +404,7 @@ class ModelConfigService:
|
||||
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
|
||||
# BizCode.INVALID_PARAMETER
|
||||
# )
|
||||
|
||||
|
||||
# 创建组合模型
|
||||
model_config_data = {
|
||||
"tenant_id": tenant_id,
|
||||
@@ -418,49 +423,51 @@ class ModelConfigService:
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush()
|
||||
|
||||
|
||||
# 关联 API Keys
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
model.api_keys.append(api_key)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
|
||||
tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""更新组合模型"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
if not existing_model.is_composite:
|
||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
||||
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
for model_config in api_key.model_configs:
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = existing_model.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
|
||||
# 更新基本信息
|
||||
existing_model.name = model_data.name
|
||||
# existing_model.type = model_data.type
|
||||
@@ -471,14 +478,14 @@ class ModelConfigService:
|
||||
existing_model.is_public = model_data.is_public
|
||||
if "load_balance_strategy" in model_data.model_fields_set:
|
||||
existing_model.load_balance_strategy = model_data.load_balance_strategy
|
||||
|
||||
|
||||
# 更新 API Keys 关联
|
||||
existing_model.api_keys.clear()
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
existing_model.api_keys.append(api_key)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_model)
|
||||
return existing_model
|
||||
@@ -532,7 +539,7 @@ class ModelApiKeyService:
|
||||
"""根据provider为多个ModelConfig创建API Key"""
|
||||
created_keys = []
|
||||
failed_models = [] # 记录验证失败的模型
|
||||
|
||||
|
||||
for model_config_id in data.model_config_ids:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
@@ -540,10 +547,10 @@ class ModelApiKeyService:
|
||||
|
||||
data.is_omni = model_config.is_omni
|
||||
data.capability = model_config.capability
|
||||
|
||||
|
||||
# 从ModelBase获取model_name
|
||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||
|
||||
|
||||
# 检查是否存在API Key(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
@@ -553,7 +560,7 @@ class ModelApiKeyService:
|
||||
ModelApiKey.model_name == model_name,
|
||||
ModelConfig.tenant_id == model_config.tenant_id
|
||||
).first()
|
||||
|
||||
|
||||
if existing_key:
|
||||
# 如果已存在,重新激活并更新
|
||||
if existing_key.is_active:
|
||||
@@ -566,14 +573,14 @@ class ModelApiKeyService:
|
||||
existing_key.model_name = model_name
|
||||
existing_key.capability = data.capability
|
||||
existing_key.is_omni = data.is_omni
|
||||
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
|
||||
created_keys.append(existing_key)
|
||||
continue
|
||||
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
@@ -589,7 +596,7 @@ class ModelApiKeyService:
|
||||
# 记录验证失败的模型,但不抛出异常
|
||||
failed_models.append(model_name)
|
||||
continue
|
||||
|
||||
|
||||
# 创建API Key
|
||||
api_key_data = ModelApiKeyCreate(
|
||||
model_config_ids=[model_config_id],
|
||||
@@ -606,12 +613,12 @@ class ModelApiKeyService:
|
||||
)
|
||||
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
|
||||
created_keys.append(api_key_obj)
|
||||
|
||||
|
||||
if created_keys:
|
||||
db.commit()
|
||||
for key in created_keys:
|
||||
db.refresh(key)
|
||||
|
||||
|
||||
return created_keys, failed_models
|
||||
|
||||
@staticmethod
|
||||
@@ -626,7 +633,7 @@ class ModelApiKeyService:
|
||||
api_key_data.is_omni = model_config.is_omni
|
||||
if api_key_data.capability is None:
|
||||
api_key_data.capability = model_config.capability
|
||||
|
||||
|
||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
@@ -650,15 +657,15 @@ class ModelApiKeyService:
|
||||
existing_key.model_name = api_key_data.model_name
|
||||
existing_key.capability = api_key_data.capability
|
||||
existing_key.is_omni = api_key_data.is_omni
|
||||
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_key)
|
||||
return existing_key
|
||||
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
@@ -691,7 +698,7 @@ class ModelApiKeyService:
|
||||
# 获取关联的模型配置以获取模型类型
|
||||
if existing_api_key.model_configs:
|
||||
model_config = existing_api_key.model_configs[0]
|
||||
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name or existing_api_key.model_name,
|
||||
@@ -729,15 +736,15 @@ class ModelApiKeyService:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
return None
|
||||
|
||||
|
||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||
if not api_keys:
|
||||
return None
|
||||
|
||||
|
||||
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
|
||||
|
||||
# 否则返回第一个
|
||||
return api_keys[0]
|
||||
|
||||
@@ -760,20 +767,19 @@ class ModelApiKeyService:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
|
||||
class ModelBaseService:
|
||||
"""基础模型服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
|
||||
models = ModelBaseRepository.get_list(db, query)
|
||||
|
||||
|
||||
provider_groups = {}
|
||||
for m in models:
|
||||
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
|
||||
if tenant_id:
|
||||
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
|
||||
|
||||
|
||||
provider = m.provider
|
||||
if provider not in provider_groups:
|
||||
provider_groups[provider] = {
|
||||
@@ -781,7 +787,7 @@ class ModelBaseService:
|
||||
"models": []
|
||||
}
|
||||
provider_groups[provider]["models"].append(model_dict)
|
||||
|
||||
|
||||
return list(provider_groups.values())
|
||||
|
||||
@staticmethod
|
||||
@@ -823,10 +829,10 @@ class ModelBaseService:
|
||||
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
|
||||
if not model_base:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
|
||||
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
|
||||
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
model_config_data = {
|
||||
"model_id": model_base_id,
|
||||
"tenant_id": tenant_id,
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
"""基于分享链接的聊天服务"""
|
||||
import uuid
|
||||
import time
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models import MultiAgentConfig
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.repositories import knowledge_repository
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
from app.repositories import knowledge_repository
|
||||
import json
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -118,6 +116,7 @@ class SharedChatService:
|
||||
|
||||
return conversation
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def chat(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -136,10 +135,7 @@ class SharedChatService:
|
||||
config_id = actual_config_id
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id = None
|
||||
@@ -273,11 +269,6 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
@@ -324,6 +315,7 @@ class SharedChatService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -341,8 +333,6 @@ class SharedChatService:
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
import json
|
||||
|
||||
start_time = time.time()
|
||||
@@ -486,11 +476,6 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
@@ -585,6 +570,7 @@ class SharedChatService:
|
||||
|
||||
return conversations, total
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def multi_agent_chat(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -680,6 +666,7 @@ class SharedChatService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def multi_agent_chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
|
||||
Reference in New Issue
Block a user