Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

# Conflicts:
#	api/app/core/agent/langchain_agent.py
This commit is contained in:
Mark
2026-02-04 15:51:44 +08:00
46 changed files with 1219 additions and 117 deletions

View File

@@ -0,0 +1,151 @@
"""Agent Middleware - 动态技能过滤"""
import uuid
from typing import List, Dict, Any, Optional
from langchain_core.runnables import RunnablePassthrough
from app.services.skill_service import SkillService
from app.repositories.skill_repository import SkillRepository
class AgentMiddleware:
    """Agent middleware that dynamically filters and loads skills.

    An instance holds the skill IDs configured for an agent. The static
    helpers decide, per user message, which skills are activated, which tools
    remain available, and which skill prompts should be merged into the
    system prompt.
    """

    def __init__(self, skill_ids: Optional[List[str]] = None):
        """Initialize the middleware.

        Args:
            skill_ids: IDs of the skills attached to the agent. ``None`` is
                treated as an empty list.
        """
        self.skill_ids = skill_ids or []

    @staticmethod
    def filter_tools(
        tools: List,
        message: str = "",
        skill_configs: Optional[Dict[str, Any]] = None,
        tool_to_skill_map: Optional[Dict[str, str]] = None,
    ) -> tuple[List, List[str]]:
        """Filter tools according to the message and the skill configuration.

        A skill is activated when it is enabled and either declares no
        keywords or at least one keyword occurs (case-insensitively) in the
        message. A tool is kept when it is not mapped to any skill, or when
        the skill it is mapped to has been activated.

        Args:
            tools: All available tools.
            message: The user message used for keyword matching. ``None`` is
                treated as an empty message.
            skill_configs: Mapping of
                ``{skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}``.
            tool_to_skill_map: Mapping of ``{tool_name: skill_id}``.

        Returns:
            A ``(filtered_tools, activated_skill_ids)`` tuple.
        """
        if not tools:
            return [], []

        # Without any skill configuration there is nothing to filter on.
        if not skill_configs:
            return tools, []

        # Guard against a missing message so keyword matching cannot crash.
        message_lower = (message or "").lower()

        activated_skill_ids = []
        for skill_id, config in skill_configs.items():
            config = config or {}
            if not config.get('enabled', True):
                continue
            keywords = config.get('keywords') or []
            # No keyword restriction, or the message mentions a keyword.
            if not keywords or any(str(kw).lower() in message_lower for kw in keywords):
                activated_skill_ids.append(skill_id)

        # Without a tool -> skill mapping every tool stays available.
        if not tool_to_skill_map:
            return tools, activated_skill_ids

        activated = set(activated_skill_ids)
        filtered_tools = []
        for tool in tools:
            tool_name = getattr(tool, 'name', str(id(tool)))
            # Keep tools owned by no skill (base tools) and tools whose
            # owning skill has been activated.
            if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated:
                filtered_tools.append(tool)

        return filtered_tools, activated_skill_ids

    def load_skill_tools(
        self,
        db,
        tenant_id: uuid.UUID,
        base_tools: Optional[List] = None,
    ) -> tuple[List, Dict[str, Any], Dict[str, str]]:
        """Load the tools associated with the configured skills.

        Base tools always win on a name clash: a skill tool is only added when
        no tool with the same name is already registered, and only such added
        tools are recorded in the tool -> skill mapping.

        Args:
            db: Database session handed to the repository/service layer.
            tenant_id: Tenant ID used to scope the skill lookup.
            base_tools: Tools that belong to no skill.

        Returns:
            A ``(tools, skill_configs, tool_to_skill_map)`` tuple where
            ``tool_to_skill_map`` is ``{tool_name: skill_id}``.
        """
        import logging

        logger = logging.getLogger(__name__)

        tools_dict = {}
        tool_to_skill_map = {}  # tool name -> skill ID

        # Base tools belong to no skill, so they never enter the mapping.
        for tool in base_tools or []:
            tools_dict[getattr(tool, 'name', str(id(tool)))] = tool

        skill_configs = {}
        if not self.skill_ids:
            return list(tools_dict.values()), skill_configs, tool_to_skill_map

        for skill_id in self.skill_ids:
            try:
                skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
                if skill and skill.is_active:
                    # Copy the stored config so that adding prompt/name does
                    # not mutate the dict owned by the loaded skill object.
                    config = dict(skill.config or {})
                    config['prompt'] = skill.prompt
                    config['name'] = skill.name
                    skill_configs[skill_id] = config
            except Exception:
                # Best effort: a broken skill must not take the agent down,
                # but the failure should be visible in the logs.
                logger.warning("Failed to load skill %s; skipping it", skill_id, exc_info=True)
                continue

        # Load the skill tools together with their tool -> skill mapping.
        skill_tools, skill_tool_map = SkillService.load_skill_tools(db, self.skill_ids, tenant_id)

        # Only add skill tools that do not clash with an existing tool name.
        for tool in skill_tools:
            tool_name = getattr(tool, 'name', str(id(tool)))
            if tool_name not in tools_dict:
                tools_dict[tool_name] = tool
                # Carry over the mapping for the tool that was actually added.
                if tool_name in skill_tool_map:
                    tool_to_skill_map[tool_name] = skill_tool_map[tool_name]

        return list(tools_dict.values()), skill_configs, tool_to_skill_map

    @staticmethod
    def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
        """Build the merged prompt for the activated skills.

        Args:
            activated_skill_ids: IDs of the skills that were activated.
            skill_configs: Mapping of skill ID to its configuration dict.

        Returns:
            The prompts of all activated skills that define one, each under a
            ``# <name>`` heading and separated by a blank line; an empty
            string when no activated skill has a prompt.
        """
        prompts = []
        for skill_id in activated_skill_ids:
            config = skill_configs.get(skill_id) or {}
            prompt = config.get('prompt')
            name = config.get('name', 'Skill')
            if prompt:
                prompts.append(f"# {name}\n{prompt}")
        return "\n\n".join(prompts)

    @staticmethod
    def create_runnable():
        """Create the runnable form of the middleware (a ``RunnablePassthrough``)."""
        return RunnablePassthrough()

View File

@@ -291,6 +291,7 @@ class LangChainAgent:
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
#TODO: 魔法数字
scope=6
try:
@@ -300,6 +301,12 @@ class LangChainAgent:
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
# Handle case where no session exists in Redis (returns False)
if not result or result is False:
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
return
if type=="chunk" or type=="aggregate":
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
@@ -307,7 +314,14 @@ class LangChainAgent:
repo.upsert(end_user_id, chunk_data)
logger.info(f'写入短长期:')
else:
# TODO: This branch handles type="time" strategy, currently unused.
# Will be activated when time-based long-term storage is implemented.
# TODO: 魔法数字 - extract 5 to a constant
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
return
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
@@ -507,9 +521,12 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
'''长期'''
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
response = {
"content": content,
@@ -693,9 +710,13 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
except Exception as e: