Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
# Conflicts: # api/app/core/agent/langchain_agent.py
This commit is contained in:
151
api/app/core/agent/agent_middleware.py
Normal file
151
api/app/core/agent/agent_middleware.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Agent Middleware - 动态技能过滤"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
|
||||
from app.services.skill_service import SkillService
|
||||
from app.repositories.skill_repository import SkillRepository
|
||||
|
||||
|
||||
class AgentMiddleware:
|
||||
"""Agent 中间件 - 用于动态过滤和加载技能"""
|
||||
|
||||
def __init__(self, skill_ids: Optional[List[str]] = None):
|
||||
"""
|
||||
初始化中间件
|
||||
|
||||
Args:
|
||||
skill_ids: 技能ID列表
|
||||
"""
|
||||
self.skill_ids = skill_ids or []
|
||||
|
||||
@staticmethod
|
||||
def filter_tools(
|
||||
tools: List,
|
||||
message: str = "",
|
||||
skill_configs: Dict[str, Any] = None,
|
||||
tool_to_skill_map: Dict[str, str] = None
|
||||
) -> tuple[List, List[str]]:
|
||||
"""
|
||||
根据消息内容和技能配置动态过滤工具
|
||||
|
||||
Args:
|
||||
tools: 所有可用工具列表
|
||||
message: 用户消息(可用于智能过滤)
|
||||
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
|
||||
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
|
||||
|
||||
Returns:
|
||||
(过滤后的工具列表, 激活的技能ID列表)
|
||||
"""
|
||||
if not tools:
|
||||
return [], []
|
||||
|
||||
# 如果没有技能配置,返回所有工具
|
||||
if not skill_configs:
|
||||
return tools, []
|
||||
|
||||
# 基于关键词匹配激活技能
|
||||
activated_skill_ids = []
|
||||
message_lower = message.lower()
|
||||
|
||||
for skill_id, config in skill_configs.items():
|
||||
if not config.get('enabled', True):
|
||||
continue
|
||||
|
||||
keywords = config.get('keywords', [])
|
||||
# 如果没有关键词限制,或消息包含关键词,则激活该技能
|
||||
if not keywords or any(kw.lower() in message_lower for kw in keywords):
|
||||
activated_skill_ids.append(skill_id)
|
||||
|
||||
# 如果没有工具映射关系,返回所有工具
|
||||
if not tool_to_skill_map:
|
||||
return tools, activated_skill_ids
|
||||
|
||||
# 根据激活的技能过滤工具
|
||||
filtered_tools = []
|
||||
for tool in tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
# 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留
|
||||
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
|
||||
filtered_tools.append(tool)
|
||||
|
||||
return filtered_tools, activated_skill_ids
|
||||
|
||||
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
加载技能关联的工具
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
tenant_id: 租户id
|
||||
base_tools: 基础工具列表
|
||||
|
||||
Returns:
|
||||
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
|
||||
"""
|
||||
|
||||
tools_dict = {}
|
||||
tool_to_skill_map = {} # 工具名称到技能ID的映射
|
||||
|
||||
if base_tools:
|
||||
for tool in base_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
tools_dict[tool_name] = tool
|
||||
# base_tools 不属于任何 skill,不加入映射
|
||||
|
||||
skill_configs = {}
|
||||
|
||||
if self.skill_ids:
|
||||
for skill_id in self.skill_ids:
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||
if skill and skill.is_active:
|
||||
# 保存技能配置(包含prompt)
|
||||
config = skill.config or {}
|
||||
config['prompt'] = skill.prompt
|
||||
config['name'] = skill.name
|
||||
skill_configs[skill_id] = config
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 加载技能工具并获取映射关系
|
||||
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, self.skill_ids, tenant_id)
|
||||
|
||||
# 只添加不冲突的 skill_tools
|
||||
for tool in skill_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
if tool_name not in tools_dict:
|
||||
tools_dict[tool_name] = tool
|
||||
# 复制映射关系
|
||||
if tool_name in skill_tool_map:
|
||||
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
|
||||
|
||||
return list(tools_dict.values()), skill_configs, tool_to_skill_map
|
||||
|
||||
@staticmethod
|
||||
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
根据激活的技能ID获取对应的提示词
|
||||
|
||||
Args:
|
||||
activated_skill_ids: 被激活的技能ID列表
|
||||
skill_configs: 技能配置字典
|
||||
|
||||
Returns:
|
||||
合并后的提示词
|
||||
"""
|
||||
prompts = []
|
||||
for skill_id in activated_skill_ids:
|
||||
config = skill_configs.get(skill_id, {})
|
||||
prompt = config.get('prompt')
|
||||
name = config.get('name', 'Skill')
|
||||
if prompt:
|
||||
prompts.append(f"# {name}\n{prompt}")
|
||||
|
||||
return "\n\n".join(prompts) if prompts else ""
|
||||
|
||||
@staticmethod
|
||||
def create_runnable():
|
||||
"""创建可运行的中间件"""
|
||||
return RunnablePassthrough()
|
||||
@@ -291,6 +291,7 @@ class LangChainAgent:
|
||||
|
||||
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
||||
db = next(get_db())
|
||||
#TODO: 魔法数字
|
||||
scope=6
|
||||
|
||||
try:
|
||||
@@ -300,6 +301,12 @@ class LangChainAgent:
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
|
||||
# Handle case where no session exists in Redis (returns False)
|
||||
if not result or result is False:
|
||||
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
|
||||
return
|
||||
|
||||
if type=="chunk" or type=="aggregate":
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
@@ -307,7 +314,14 @@ class LangChainAgent:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'写入短长期:')
|
||||
else:
|
||||
# TODO: This branch handles type="time" strategy, currently unused.
|
||||
# Will be activated when time-based long-term storage is implemented.
|
||||
# TODO: 魔法数字 - extract 5 to a constant
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||
# Handle case where no session exists in Redis (returns False or empty)
|
||||
if not long_time_data or long_time_data is False:
|
||||
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
|
||||
return
|
||||
long_messages = await messages_parse(long_time_data)
|
||||
repo.upsert(end_user_id, long_messages)
|
||||
logger.info(f'写入短长期:')
|
||||
@@ -507,9 +521,12 @@ class LangChainAgent:
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
long_term_messages=await agent_chat_messages(message_chat,content)
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
'''长期'''
|
||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
||||
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
|
||||
response = {
|
||||
"content": content,
|
||||
@@ -693,9 +710,13 @@ class LangChainAgent:
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
||||
long_term_messages = await agent_chat_messages(message_chat, full_content)
|
||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
||||
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -215,6 +215,9 @@ class Settings:
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||
|
||||
# model square loading
|
||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
# TODO:删除
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
print(contents)
|
||||
@@ -60,6 +61,7 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
redis_messages = []
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
@@ -91,6 +93,9 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
# Handle case where no session exists in Redis (returns False or empty)
|
||||
if not long_time_data or long_time_data is False:
|
||||
return
|
||||
format_messages = await chat_data_format(long_time_data)
|
||||
if format_messages!=[]:
|
||||
await write_messages(end_user_id, format_messages, memory_config)
|
||||
@@ -108,8 +113,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
|
||||
# Handle case where no session exists in Redis (returns False or empty)
|
||||
if not result or result is False:
|
||||
history = []
|
||||
else:
|
||||
history = await format_parsing(result)
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
@@ -40,27 +36,55 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
"""Dispatch long-term memory storage to Celery background tasks.
|
||||
|
||||
Args:
|
||||
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
|
||||
langchain_messages: List of messages to store
|
||||
memory_config: Memory configuration ID (string)
|
||||
end_user_id: End user identifier
|
||||
scope: Window size for 'chunk' strategy (default: 6)
|
||||
"""
|
||||
from app.tasks import (
|
||||
long_term_storage_window_task,
|
||||
# TODO: Uncomment when implemented
|
||||
# long_term_storage_time_task,
|
||||
# long_term_storage_aggregate_task,
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
|
||||
"""方案三:聚合判断"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Convert config to string if needed
|
||||
config_id = str(memory_config) if memory_config else ''
|
||||
|
||||
if long_term_type == 'chunk':
|
||||
# Strategy 1: Window-based batching (6 rounds of dialogue)
|
||||
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
|
||||
long_term_storage_window_task.delay(
|
||||
end_user_id=end_user_id,
|
||||
langchain_messages=langchain_messages,
|
||||
config_id=config_id,
|
||||
scope=scope
|
||||
)
|
||||
# TODO: Uncomment when time-based strategy is fully implemented
|
||||
# elif long_term_type == 'time':
|
||||
# # Strategy 2: Time-based retrieval
|
||||
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
|
||||
# long_term_storage_time_task.delay(
|
||||
# end_user_id=end_user_id,
|
||||
# config_id=config_id,
|
||||
# time_window=5
|
||||
# )
|
||||
# TODO: Uncomment when aggregate strategy is fully implemented
|
||||
# elif long_term_type == 'aggregate':
|
||||
# # Strategy 3: Aggregate judgment (deduplication)
|
||||
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
|
||||
# long_term_storage_aggregate_task.delay(
|
||||
# end_user_id=end_user_id,
|
||||
# langchain_messages=langchain_messages,
|
||||
# config_id=config_id
|
||||
# )
|
||||
|
||||
|
||||
# async def main():
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
provider: bedrock
|
||||
enabled: false
|
||||
models:
|
||||
- name: ai21
|
||||
type: llm
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
provider: dashscope
|
||||
enabled: false
|
||||
models:
|
||||
- name: deepseek-r1-distill-qwen-14b
|
||||
type: llm
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.models_model import ModelBase, ModelProvider
|
||||
|
||||
|
||||
@@ -19,31 +19,9 @@ def _load_yaml_config(provider: ModelProvider) -> list[dict]:
|
||||
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
# 检查是否需要加载(默认为 true)
|
||||
if not data.get('enabled', True):
|
||||
return []
|
||||
|
||||
return data.get('models', [])
|
||||
|
||||
|
||||
def _disable_yaml_config(provider: ModelProvider) -> None:
|
||||
"""将YAML文件的enabled标志设置为false"""
|
||||
config_dir = Path(__file__).parent
|
||||
config_file = config_dir / f"{provider.value}_models.yaml"
|
||||
|
||||
if not config_file.exists():
|
||||
return
|
||||
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
data['enabled'] = False
|
||||
|
||||
with open(config_file, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(data, f, allow_unicode=True, sort_keys=False)
|
||||
|
||||
|
||||
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
|
||||
"""
|
||||
加载模型配置到数据库
|
||||
@@ -75,8 +53,7 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
|
||||
if not silent:
|
||||
print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")
|
||||
|
||||
# provider_success = 0
|
||||
|
||||
for model_data in models:
|
||||
try:
|
||||
# 检查模型是否已存在
|
||||
@@ -93,7 +70,6 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
if not silent:
|
||||
print(f"更新成功: {model_data['name']}")
|
||||
result["success"] += 1
|
||||
# provider_success += 1
|
||||
else:
|
||||
# 创建新模型
|
||||
model = ModelBase(**model_data)
|
||||
@@ -102,17 +78,12 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
if not silent:
|
||||
print(f"添加成功: {model_data['name']}")
|
||||
result["success"] += 1
|
||||
# provider_success += 1
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
if not silent:
|
||||
print(f"添加失败: {model_data['name']} - {str(e)}")
|
||||
result["failed"] += 1
|
||||
|
||||
# 如果该供应商的模型全部加载成功,将enabled设置为false
|
||||
# if provider_success == len(models):
|
||||
_disable_yaml_config(provider)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
provider: openai
|
||||
enabled: false
|
||||
models:
|
||||
- name: chatgpt-4o-latest
|
||||
type: llm
|
||||
|
||||
@@ -2,6 +2,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import urllib.parse
|
||||
from string import Template
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
Reference in New Issue
Block a user