Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

This commit is contained in:
Mark
2026-01-12 21:11:09 +08:00
38 changed files with 811 additions and 296 deletions

View File

@@ -70,24 +70,25 @@ async def trigger_forgetting_cycle(
ApiResponse: 包含遗忘报告的响应
"""
workspace_id = current_user.current_workspace_id
end_user_id = payload.end_user_id # 从 payload 中获取 end_user_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 通过 group_id 获取关联的 config_id
# 通过 end_user_id 获取关联的 config_id
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(payload.group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
api_logger.warning(f"终端用户 {payload.group_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {payload.group_id} 未关联记忆配置", "memory_config_id is None")
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 group_id={payload.group_id} 获取到 config_id={config_id}")
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
@@ -97,7 +98,7 @@ async def trigger_forgetting_cycle(
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: "
f"group_id={payload.group_id}, config_id={config_id}, max_batch={payload.max_merge_batch_size}, "
f"end_user_id={end_user_id}, config_id={config_id}, max_batch={payload.max_merge_batch_size}, "
f"min_days={payload.min_days_since_access}"
)
@@ -105,7 +106,7 @@ async def trigger_forgetting_cycle(
# 调用服务层执行遗忘周期
report = await forget_service.trigger_forgetting_cycle(
db=db,
group_id=payload.group_id,
group_id=end_user_id, # 服务层方法的参数名是 group_id
max_merge_batch_size=payload.max_merge_batch_size,
min_days_since_access=payload.min_days_since_access,
config_id=config_id

View File

@@ -641,6 +641,7 @@ class AccessHistoryManager:
n.access_count = $access_count,
n.version = $new_version
RETURN n.id as id,
n.statement as statement,
n.activation_value as activation_value,
n.access_history as access_history,
n.last_access_time as last_access_time,

View File

@@ -54,9 +54,9 @@ async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]:
Returns:
符合反思范围的记忆数据列表。
"""
if REFLEXION_RANGE == "retrieval":
if REFLEXION_RANGE == "partial":
return await get_data(host_id)
elif REFLEXION_RANGE == "database":
elif REFLEXION_RANGE == "all":
return []
else:
raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}")

View File

@@ -260,8 +260,7 @@ class ConditionBase(ABC):
raise RuntimeError("Unsupported variable type")
def check(self, no_right=False):
left = self.pool.get(self.left_selector.variable_selector)
if not isinstance(left, self.type_limit):
if not isinstance(self.left_value, self.type_limit):
raise TypeError(f"The variable to be compared on must be of {self.type_limit} type")
if not no_right:
right = self.resolve_right_literal_value()

View File

@@ -37,7 +37,7 @@ class ParamsConfig(BaseModel):
)
required: bool = Field(
...,
default=False,
description="Whether the parameter is required"
)
@@ -59,6 +59,6 @@ class ParameterExtractorNodeConfig(BaseNodeConfig):
)
prompt: str = Field(
...,
default="",
description="User-provided supplemental prompt"
)

View File

@@ -157,9 +157,17 @@ class ParameterExtractorNode(BaseNode):
messages = [
("system", system_prompt),
("user", self._render_template(self.typed_config.prompt, state)),
("user", rendered_user_prompt),
]
if self.typed_config.prompt:
messages.extend([
("user", self._render_template(self.typed_config.prompt, state)),
("user", rendered_user_prompt),
])
else:
messages.extend([
("user", rendered_user_prompt),
])
model_resp = await llm.ainvoke(messages)
result = json_repair.repair_json(model_resp.content, return_objects=True)

View File

@@ -727,6 +727,7 @@ SET m += {
dialog_id: summary.dialog_id,
chunk_ids: summary.chunk_ids,
content: summary.content,
memory_type: summary.memory_type,
summary_embedding: summary.summary_embedding,
config_id: summary.config_id,
importance_score: CASE WHEN summary.importance_score IS NOT NULL THEN summary.importance_score ELSE coalesce(m.importance_score, 0.5) END,

View File

@@ -173,9 +173,10 @@ class MemoryConfigValidation(BaseModel):
chunker_strategy: str = Field(default="RecursiveChunker", min_length=1, max_length=100)
reflexion_enabled: bool = Field(default=False)
reflexion_iteration_period: int = Field(default=3, ge=1, le=100)
reflexion_range: Literal["retrieval", "all"] = Field(default="retrieval")
reflexion_baseline: Literal["time", "fact", "time_and_fact"] = Field(default="time")
reflexion_range: Literal["partial", "all"] = Field(default="partial")
reflexion_baseline: Literal["TIME", "FACT", "HYBRID"] = Field(default="TIME")
llm_params: Dict[str, Any] = Field(default_factory=dict)
embedding_params: Dict[str, Any] = Field(default_factory=dict)
config_version: str = Field(default="2.0", min_length=1, max_length=10)

View File

@@ -292,8 +292,8 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field(
"3", description="反思迭代周期,单位小时"
)
reflexion_range: Optional[Literal["retrieval", "database"]] = Field(
"retrieval", description="反思范围:部分/全部"
reflexion_range: Optional[Literal["partial", "all"]] = Field(
"partial", description="反思范围:部分/全部"
)
baseline: Optional[Literal["TIME", "FACT", "TIME-FACT"]] = Field(
"TIME", description="基线:时间/事实/时间和事实"
@@ -409,7 +409,7 @@ class ForgettingTriggerRequest(BaseModel):
"""手动触发遗忘周期请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
group_id: str = Field(..., description="组ID即终端用户ID必填")
end_user_id: str = Field(..., description="组ID即终端用户ID必填")
max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数默认100")
min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数默认30天")

View File

@@ -113,8 +113,8 @@ class AppChatService:
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
@@ -263,8 +263,8 @@ class AppChatService:
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)

View File

@@ -318,8 +318,8 @@ class DraftRunService:
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
@@ -546,8 +546,8 @@ class DraftRunService:
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)

View File

@@ -396,6 +396,7 @@ class MemoryAgentService:
import time
start_time = time.time()
ori_message=message
end_user_id=group_id
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
@@ -528,7 +529,6 @@ class MemoryAgentService:
workflow_duration = time.time() - start
logger.info(f"Read graph workflow completed in {workflow_duration}s")
# Extract final answer
final_answer = ""
for messages in outputs:
@@ -602,18 +602,24 @@ class MemoryAgentService:
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
print(intermediate)
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception as e:
statements=[]
statements=list(set(statements))
retrieved_content.append({query:statements})
if '信息不足,无法回答' in str(final_answer) or retrieved_content!=[]:
if retrieved_content==[]:
retrieved_content=''
if '信息不足,无法回答。' != str(final_answer) :#and retrieved_content!=[]
# 使用 upsert 方法
repo.upsert(
end_user_id=group_id, # 确保这个变量在作用域内
end_user_id=end_user_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,

View File

@@ -205,8 +205,8 @@ class MemoryConfigService:
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
reflexion_enabled=memory_config.enable_self_reflexion or False,
reflexion_iteration_period=int(memory_config.iteration_period or "3"),
reflexion_range=memory_config.reflexion_range or "retrieval",
reflexion_baseline=memory_config.baseline or "time",
reflexion_range=memory_config.reflexion_range or "partial",
reflexion_baseline=memory_config.baseline or "Time",
loaded_at=datetime.now(),
# Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,

View File

@@ -490,17 +490,19 @@ class MemoryEmotion:
# 如果created_at是字符串格式尝试格式化
if isinstance(created_at, str):
formatted_created_at = self._format_datetime(created_at)
emotion_type = record.get('emotion_type')
emotion_intensity = record.get('emotion_intensity')
if emotion_type !=None:
length_data.append(emotion_intensity)
if emotion_type is not None and emotion_intensity is not None and formatted_created_at is not None:
# 使用(emotion_type, created_at)作为分组键
if emotion_type in {"joy", "surprise"}:
emotion_type='positive'
elif emotion_type in {"sadness", "fear", "anger"}:
emotion_type='negative'
elif emotion_type=='neutral':
emotion_type='neutral'
group_key = (emotion_type, formatted_created_at)
# 累加emotion_intensity
try:
emotion_groups[group_key] += float(emotion_intensity)

View File

@@ -209,7 +209,7 @@ class SharedChatService:
# 添加长期记忆工具
memory_flag=False
if memory==True:
if memory:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
memory_flag=True
@@ -219,8 +219,8 @@ class SharedChatService:
web_tools=config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled",False)
if web_search==True:
if web_search_enable==True:
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
@@ -413,8 +413,8 @@ class SharedChatService:
web_tools = config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)