Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
This commit is contained in:
@@ -70,24 +70,25 @@ async def trigger_forgetting_cycle(
|
||||
ApiResponse: 包含遗忘报告的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = payload.end_user_id # 从 payload 中获取 end_user_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 通过 group_id 获取关联的 config_id
|
||||
# 通过 end_user_id 获取关联的 config_id
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(payload.group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {payload.group_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {payload.group_id} 未关联记忆配置", "memory_config_id is None")
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 group_id={payload.group_id} 获取到 config_id={config_id}")
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
@@ -97,7 +98,7 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: "
|
||||
f"group_id={payload.group_id}, config_id={config_id}, max_batch={payload.max_merge_batch_size}, "
|
||||
f"end_user_id={end_user_id}, config_id={config_id}, max_batch={payload.max_merge_batch_size}, "
|
||||
f"min_days={payload.min_days_since_access}"
|
||||
)
|
||||
|
||||
@@ -105,7 +106,7 @@ async def trigger_forgetting_cycle(
|
||||
# 调用服务层执行遗忘周期
|
||||
report = await forget_service.trigger_forgetting_cycle(
|
||||
db=db,
|
||||
group_id=payload.group_id,
|
||||
group_id=end_user_id, # 服务层方法的参数名是 group_id
|
||||
max_merge_batch_size=payload.max_merge_batch_size,
|
||||
min_days_since_access=payload.min_days_since_access,
|
||||
config_id=config_id
|
||||
|
||||
@@ -641,6 +641,7 @@ class AccessHistoryManager:
|
||||
n.access_count = $access_count,
|
||||
n.version = $new_version
|
||||
RETURN n.id as id,
|
||||
n.statement as statement,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
|
||||
@@ -54,9 +54,9 @@ async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]:
|
||||
Returns:
|
||||
符合反思范围的记忆数据列表。
|
||||
"""
|
||||
if REFLEXION_RANGE == "retrieval":
|
||||
if REFLEXION_RANGE == "partial":
|
||||
return await get_data(host_id)
|
||||
elif REFLEXION_RANGE == "database":
|
||||
elif REFLEXION_RANGE == "all":
|
||||
return []
|
||||
else:
|
||||
raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}")
|
||||
|
||||
@@ -260,8 +260,7 @@ class ConditionBase(ABC):
|
||||
raise RuntimeError("Unsupported variable type")
|
||||
|
||||
def check(self, no_right=False):
|
||||
left = self.pool.get(self.left_selector.variable_selector)
|
||||
if not isinstance(left, self.type_limit):
|
||||
if not isinstance(self.left_value, self.type_limit):
|
||||
raise TypeError(f"The variable to be compared on must be of {self.type_limit} type")
|
||||
if not no_right:
|
||||
right = self.resolve_right_literal_value()
|
||||
|
||||
@@ -37,7 +37,7 @@ class ParamsConfig(BaseModel):
|
||||
)
|
||||
|
||||
required: bool = Field(
|
||||
...,
|
||||
default=False,
|
||||
description="Whether the parameter is required"
|
||||
)
|
||||
|
||||
@@ -59,6 +59,6 @@ class ParameterExtractorNodeConfig(BaseNodeConfig):
|
||||
)
|
||||
|
||||
prompt: str = Field(
|
||||
...,
|
||||
default="",
|
||||
description="User-provided supplemental prompt"
|
||||
)
|
||||
|
||||
@@ -157,9 +157,17 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
messages = [
|
||||
("system", system_prompt),
|
||||
("user", self._render_template(self.typed_config.prompt, state)),
|
||||
("user", rendered_user_prompt),
|
||||
|
||||
]
|
||||
if self.typed_config.prompt:
|
||||
messages.extend([
|
||||
("user", self._render_template(self.typed_config.prompt, state)),
|
||||
("user", rendered_user_prompt),
|
||||
])
|
||||
else:
|
||||
messages.extend([
|
||||
("user", rendered_user_prompt),
|
||||
])
|
||||
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||
|
||||
@@ -727,6 +727,7 @@ SET m += {
|
||||
dialog_id: summary.dialog_id,
|
||||
chunk_ids: summary.chunk_ids,
|
||||
content: summary.content,
|
||||
memory_type: summary.memory_type,
|
||||
summary_embedding: summary.summary_embedding,
|
||||
config_id: summary.config_id,
|
||||
importance_score: CASE WHEN summary.importance_score IS NOT NULL THEN summary.importance_score ELSE coalesce(m.importance_score, 0.5) END,
|
||||
|
||||
@@ -173,9 +173,10 @@ class MemoryConfigValidation(BaseModel):
|
||||
chunker_strategy: str = Field(default="RecursiveChunker", min_length=1, max_length=100)
|
||||
reflexion_enabled: bool = Field(default=False)
|
||||
reflexion_iteration_period: int = Field(default=3, ge=1, le=100)
|
||||
reflexion_range: Literal["retrieval", "all"] = Field(default="retrieval")
|
||||
reflexion_baseline: Literal["time", "fact", "time_and_fact"] = Field(default="time")
|
||||
|
||||
reflexion_range: Literal["partial", "all"] = Field(default="partial")
|
||||
reflexion_baseline: Literal["TIME", "FACT", "HYBRID"] = Field(default="TIME")
|
||||
|
||||
|
||||
llm_params: Dict[str, Any] = Field(default_factory=dict)
|
||||
embedding_params: Dict[str, Any] = Field(default_factory=dict)
|
||||
config_version: str = Field(default="2.0", min_length=1, max_length=10)
|
||||
|
||||
@@ -292,8 +292,8 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
|
||||
iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field(
|
||||
"3", description="反思迭代周期,单位小时"
|
||||
)
|
||||
reflexion_range: Optional[Literal["retrieval", "database"]] = Field(
|
||||
"retrieval", description="反思范围:部分/全部"
|
||||
reflexion_range: Optional[Literal["partial", "all"]] = Field(
|
||||
"partial", description="反思范围:部分/全部"
|
||||
)
|
||||
baseline: Optional[Literal["TIME", "FACT", "TIME-FACT"]] = Field(
|
||||
"TIME", description="基线:时间/事实/时间和事实"
|
||||
@@ -409,7 +409,7 @@ class ForgettingTriggerRequest(BaseModel):
|
||||
"""手动触发遗忘周期请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
group_id: str = Field(..., description="组ID(即终端用户ID,必填)")
|
||||
end_user_id: str = Field(..., description="组ID(即终端用户ID,必填)")
|
||||
max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)")
|
||||
min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)")
|
||||
|
||||
|
||||
@@ -113,8 +113,8 @@ class AppChatService:
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
@@ -263,8 +263,8 @@ class AppChatService:
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
|
||||
@@ -318,8 +318,8 @@ class DraftRunService:
|
||||
web_tools = agent_config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
@@ -546,8 +546,8 @@ class DraftRunService:
|
||||
web_tools = agent_config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
|
||||
@@ -396,6 +396,7 @@ class MemoryAgentService:
|
||||
import time
|
||||
start_time = time.time()
|
||||
ori_message=message
|
||||
end_user_id=group_id
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
@@ -528,7 +529,6 @@ class MemoryAgentService:
|
||||
|
||||
workflow_duration = time.time() - start
|
||||
logger.info(f"Read graph workflow completed in {workflow_duration}s")
|
||||
|
||||
# Extract final answer
|
||||
final_answer = ""
|
||||
for messages in outputs:
|
||||
@@ -602,18 +602,24 @@ class MemoryAgentService:
|
||||
repo = ShortTermMemoryRepository(db)
|
||||
if str(search_switch)!="2":
|
||||
for intermediate in intermediate_outputs:
|
||||
print(intermediate)
|
||||
intermediate_type=intermediate['type']
|
||||
if intermediate_type=="search_result":
|
||||
query=intermediate['query']
|
||||
raw_results=intermediate['raw_results']
|
||||
reranked_results=raw_results.get('reranked_results',[])
|
||||
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||
try:
|
||||
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||
except Exception as e:
|
||||
statements=[]
|
||||
statements=list(set(statements))
|
||||
retrieved_content.append({query:statements})
|
||||
if '信息不足,无法回答' in str(final_answer) or retrieved_content!=[]:
|
||||
if retrieved_content==[]:
|
||||
retrieved_content=''
|
||||
if '信息不足,无法回答。' != str(final_answer) :#and retrieved_content!=[]
|
||||
# 使用 upsert 方法
|
||||
repo.upsert(
|
||||
end_user_id=group_id, # 确保这个变量在作用域内
|
||||
end_user_id=end_user_id, # 确保这个变量在作用域内
|
||||
messages=ori_message,
|
||||
aimessages=final_answer,
|
||||
retrieved_content=retrieved_content,
|
||||
|
||||
@@ -205,8 +205,8 @@ class MemoryConfigService:
|
||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||
reflexion_iteration_period=int(memory_config.iteration_period or "3"),
|
||||
reflexion_range=memory_config.reflexion_range or "retrieval",
|
||||
reflexion_baseline=memory_config.baseline or "time",
|
||||
reflexion_range=memory_config.reflexion_range or "partial",
|
||||
reflexion_baseline=memory_config.baseline or "Time",
|
||||
loaded_at=datetime.now(),
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
|
||||
@@ -490,17 +490,19 @@ class MemoryEmotion:
|
||||
# 如果created_at是字符串格式,尝试格式化
|
||||
if isinstance(created_at, str):
|
||||
formatted_created_at = self._format_datetime(created_at)
|
||||
|
||||
emotion_type = record.get('emotion_type')
|
||||
emotion_intensity = record.get('emotion_intensity')
|
||||
if emotion_type !=None:
|
||||
length_data.append(emotion_intensity)
|
||||
|
||||
|
||||
if emotion_type is not None and emotion_intensity is not None and formatted_created_at is not None:
|
||||
# 使用(emotion_type, created_at)作为分组键
|
||||
if emotion_type in {"joy", "surprise"}:
|
||||
emotion_type='positive'
|
||||
elif emotion_type in {"sadness", "fear", "anger"}:
|
||||
emotion_type='negative'
|
||||
elif emotion_type=='neutral':
|
||||
emotion_type='neutral'
|
||||
group_key = (emotion_type, formatted_created_at)
|
||||
|
||||
# 累加emotion_intensity
|
||||
try:
|
||||
emotion_groups[group_key] += float(emotion_intensity)
|
||||
|
||||
@@ -209,7 +209,7 @@ class SharedChatService:
|
||||
|
||||
# 添加长期记忆工具
|
||||
memory_flag=False
|
||||
if memory==True:
|
||||
if memory:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_flag=True
|
||||
@@ -219,8 +219,8 @@ class SharedChatService:
|
||||
web_tools=config.get("tools")
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled",False)
|
||||
if web_search==True:
|
||||
if web_search_enable==True:
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
@@ -413,8 +413,8 @@ class SharedChatService:
|
||||
web_tools = config.get("tools")
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user