feat(agent, memory): add agent-perceived memory writing
This commit is contained in:
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||
from app.models import WorkflowConfig
|
||||
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.schemas import FileType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -43,18 +44,17 @@ class AppChatService:
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
user_id: Optional[str] = None,
|
||||
files: list[FileInput],
|
||||
user_id: str,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None
|
||||
workspace_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
|
||||
# 应用 features 配置
|
||||
features_config: dict = config.features or {}
|
||||
@@ -93,7 +93,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
|
||||
user_id)
|
||||
tools.extend(kb_tools)
|
||||
memory_flag = False
|
||||
if memory:
|
||||
@@ -168,11 +169,6 @@ class AppChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
@@ -229,6 +225,21 @@ class AppChatService:
|
||||
# 保存消息
|
||||
if audio_url:
|
||||
assistant_meta["audio_url"] = audio_url
|
||||
if memory_flag:
|
||||
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||
memory_config_id: str = connected_config.get("memory_config_id")
|
||||
messages = [
|
||||
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||
{"role": "assistant", "content": result["content"]}
|
||||
]
|
||||
if memory_config_id:
|
||||
await write_long_term(
|
||||
storage_type,
|
||||
user_id,
|
||||
messages,
|
||||
user_rag_memory_id,
|
||||
memory_config_id
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -264,20 +275,19 @@ class AppChatService:
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
files: list[FileInput],
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None
|
||||
workspace_id: Optional[str] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
message_id = uuid.uuid4()
|
||||
|
||||
# 应用 features 配置
|
||||
@@ -319,7 +329,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
|
||||
config.knowledge_retrieval, user_id)
|
||||
tools.extend(kb_tools)
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
@@ -411,11 +422,6 @@ class AppChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
@@ -459,7 +465,7 @@ class AppChatService:
|
||||
|
||||
# 保存消息
|
||||
human_meta = {
|
||||
"files":[],
|
||||
"files": [],
|
||||
"history_files": {}
|
||||
}
|
||||
assistant_meta = {
|
||||
@@ -484,6 +490,22 @@ class AppChatService:
|
||||
|
||||
if stream_audio_url:
|
||||
assistant_meta["audio_url"] = stream_audio_url
|
||||
|
||||
if memory_flag:
|
||||
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||
memory_config_id: str = connected_config.get("memory_config_id")
|
||||
messages = [
|
||||
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||
{"role": "assistant", "content": full_content}
|
||||
]
|
||||
if memory_config_id:
|
||||
await write_long_term(
|
||||
storage_type,
|
||||
user_id,
|
||||
messages,
|
||||
user_rag_memory_id,
|
||||
memory_config_id
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -618,7 +640,6 @@ class AppChatService:
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
|
||||
# 3. 流式执行任务
|
||||
async for event in orchestrator.execute_stream(
|
||||
message=message,
|
||||
|
||||
Reference in New Issue
Block a user