From 117e29fbe3f525e3b00fc90be87fdafc470e768f Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 3 Apr 2026 16:46:09 +0800 Subject: [PATCH 01/85] fix(memory): improve optimistic lock resilience in access history manager - Increase max_retries from 3 to 5 for concurrent conflict recovery - Add randomized exponential backoff between retries to reduce contention - Merge duplicate node accesses in batch operations to avoid self-conflicts - Support access_times parameter for merged batch access counting - Add Community node label support in atomic update content field map --- .../access_history_manager.py | 62 ++++++++++++++----- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index a71c0957..cc477330 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -57,7 +57,7 @@ class AccessHistoryManager: self, connector: Neo4jConnector, actr_calculator: ACTRCalculator, - max_retries: int = 3 + max_retries: int = 5 ): """ 初始化访问历史管理器 @@ -76,7 +76,8 @@ class AccessHistoryManager: node_id: str, node_label: str, end_user_id: Optional[str] = None, - current_time: Optional[datetime] = None + current_time: Optional[datetime] = None, + access_times: int = 1 ) -> Dict[str, Any]: """ 记录节点访问并原子性更新所有相关字段 @@ -93,6 +94,7 @@ class AccessHistoryManager: node_label: 节点标签(Statement, ExtractedEntity, MemorySummary) end_user_id: 组ID(可选,用于过滤) current_time: 当前时间(可选,默认使用系统时间) + access_times: 本次访问次数(默认1,批量合并时可能大于1) Returns: Dict[str, Any]: 更新后的节点数据,包含: @@ -134,7 +136,8 @@ class AccessHistoryManager: update_data = await self._calculate_update( node_data=node_data, current_time=current_time, - current_time_iso=current_time_iso + current_time_iso=current_time_iso, + 
access_times=access_times ) # 步骤3:原子性更新节点(使用事务) @@ -149,15 +152,21 @@ class AccessHistoryManager: f"成功记录访问: {node_label}[{node_id}], " f"activation={update_data['activation_value']:.4f}, " f"access_count={update_data['access_count']}" + f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}" ) return updated_node except Exception as e: if attempt < self.max_retries - 1: + # 随机退避:避免并发请求同时重试再次冲突 + import random + backoff = random.uniform(0.05, 0.2) * (attempt + 1) logger.warning( - f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}" + f"访问记录失败(尝试 {attempt + 1}/{self.max_retries})," + f"{backoff:.3f}s 后重试: {str(e)}" ) + await asyncio.sleep(backoff) continue else: logger.error( @@ -179,10 +188,11 @@ class AccessHistoryManager: 批量记录多个节点的访问 为提高性能,批量更新多个节点的访问历史。 - 每个节点独立更新,失败的节点不影响其他节点。 + 对同一个节点的多次访问会先在内存中合并,只发起一次更新, + 从而避免同节点并发写入导致的乐观锁冲突。 Args: - node_ids: 节点ID列表 + node_ids: 节点ID列表(可包含重复ID) node_label: 节点标签(所有节点必须是同一类型) end_user_id: 组ID(可选) current_time: 当前时间(可选) @@ -196,25 +206,40 @@ class AccessHistoryManager: if current_time is None: current_time = datetime.now() - # PERFORMANCE FIX: Process all nodes in parallel instead of sequentially - tasks = [] + # 合并同一节点的访问次数,避免对同一节点并发写入 + access_count_map: Dict[str, int] = {} for node_id in node_ids: + access_count_map[node_id] = access_count_map.get(node_id, 0) + 1 + + merged_count = len(node_ids) - len(access_count_map) + if merged_count > 0: + logger.info( + f"批量访问合并: 原始={len(node_ids)}, " + f"去重后={len(access_count_map)}, 合并={merged_count}" + ) + + # 对去重后的节点并行发起更新 + tasks = [] + for node_id, access_times in access_count_map.items(): task = self.record_access( node_id=node_id, node_label=node_label, end_user_id=end_user_id, - current_time=current_time + current_time=current_time, + access_times=access_times ) - tasks.append(task) + tasks.append((node_id, task)) # Execute all tasks in parallel - task_results = await asyncio.gather(*tasks, return_exceptions=True) + task_results = await asyncio.gather( + *[t for _, t in 
tasks], return_exceptions=True + ) # Collect successful results and count failures results = [] failed_count = 0 - for node_id, result in zip(node_ids, task_results): + for (node_id, _), result in zip(tasks, task_results): if isinstance(result, Exception): failed_count += 1 logger.warning( @@ -225,7 +250,7 @@ class AccessHistoryManager: batch_duration = time.time() - batch_start logger.info( - f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, " + f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(access_count_map)}, " f"失败 {failed_count}, 耗时 {batch_duration:.4f}s" ) @@ -532,7 +557,8 @@ class AccessHistoryManager: self, node_data: Dict[str, Any], current_time: datetime, - current_time_iso: str + current_time_iso: str, + access_times: int = 1 ) -> Dict[str, Any]: """ 计算更新数据 @@ -541,6 +567,7 @@ class AccessHistoryManager: node_data: 当前节点数据 current_time: 当前时间(datetime对象) current_time_iso: 当前时间(ISO格式字符串) + access_times: 本次访问次数(合并后可能大于1) Returns: Dict[str, Any]: 更新数据,包含所有需要更新的字段 @@ -551,8 +578,8 @@ class AccessHistoryManager: if importance_score is None: importance_score = 0.5 - # 追加新的访问时间 - new_access_history = access_history + [current_time_iso] + # 追加新的访问时间(合并场景下追加多条相同时间戳) + new_access_history = access_history + [current_time_iso] * access_times # 修剪访问历史(如果过长) access_history_dt = [ @@ -642,7 +669,8 @@ class AccessHistoryManager: content_field_map = { 'Statement': 'n.statement as statement', 'MemorySummary': 'n.content as content', - 'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤 + 'ExtractedEntity': 'null as content_placeholder', # 占位符,后续会被过滤 + 'Community': 'n.summary as summary' } # 显式检查节点类型,不支持的类型抛出错误 From 00a8099857beaca14cbf5ffcb322e052eaa1b690 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 3 Apr 2026 16:55:53 +0800 Subject: [PATCH 02/85] changes:(api) Change the "jitter" to "tremble". 
--- .../forgetting_engine/access_history_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index cc477330..a5f48982 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -159,9 +159,9 @@ class AccessHistoryManager: except Exception as e: if attempt < self.max_retries - 1: - # 随机退避:避免并发请求同时重试再次冲突 + # 带抖动的指数退避:base * 2^attempt * random(0.5, 1.0) import random - backoff = random.uniform(0.05, 0.2) * (attempt + 1) + backoff = 0.05 * (2 ** attempt) * random.uniform(0.5, 1.0) logger.warning( f"访问记录失败(尝试 {attempt + 1}/{self.max_retries})," f"{backoff:.3f}s 后重试: {str(e)}" From 99862db7a05dc06e02af0179df5a6f2d48aff013 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 3 Apr 2026 18:40:03 +0800 Subject: [PATCH 03/85] refactor(forgetting-engine): replace optimistic locking with APOC atomic operations in access history manager - Replace version-based optimistic locking and retry loop with apoc.atomic.add/insert for concurrent safety - Merge duplicate accesses within a batch before updating (access_count_delta) - Simplify _calculate_update to only compute on new timestamps instead of full history rebuild - Remove max_retries instance variable (kept as param for backward compat) - Trim verbose docstrings and inline comments --- .../access_history_manager.py | 420 ++++++------------ 1 file changed, 126 insertions(+), 294 deletions(-) diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index a5f48982..e5254646 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ 
b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -42,15 +42,14 @@ class AccessHistoryManager: - access_count: 访问次数 特性: - - 原子性更新:使用Neo4j事务确保所有字段同时更新或回滚 - - 并发安全:使用乐观锁机制防止并发冲突 + - 原子性更新:使用 APOC 原子操作确保并发安全 + - 批次内合并:同一批次中对同一节点的多次访问合并为一次更新 - 一致性保证:提供一致性检查和自动修复功能 - 智能修剪:自动修剪过长的访问历史 Attributes: connector: Neo4j连接器实例 actr_calculator: ACT-R激活值计算器实例 - max_retries: 并发冲突时的最大重试次数 """ def __init__( @@ -65,12 +64,11 @@ class AccessHistoryManager: Args: connector: Neo4j连接器实例 actr_calculator: ACT-R激活值计算器实例 - max_retries: 并发冲突时的最大重试次数(默认3次) + max_retries: 已废弃,保留参数兼容性(APOC 原子操作无需重试) """ self.connector = connector self.actr_calculator = actr_calculator - self.max_retries = max_retries - + async def record_access( self, node_id: str, @@ -82,13 +80,6 @@ class AccessHistoryManager: """ 记录节点访问并原子性更新所有相关字段 - 这是核心方法,实现了: - 1. 首次访问:初始化access_history,计算初始激活值 - 2. 后续访问:追加访问历史,重新计算激活值 - 3. 历史修剪:当历史过长时自动修剪 - 4. 原子性:所有字段在单个事务中更新 - 5. 并发安全:使用乐观锁重试机制 - Args: node_id: 节点ID node_label: 节点标签(Statement, ExtractedEntity, MemorySummary) @@ -97,17 +88,11 @@ class AccessHistoryManager: access_times: 本次访问次数(默认1,批量合并时可能大于1) Returns: - Dict[str, Any]: 更新后的节点数据,包含: - - id: 节点ID - - activation_value: 更新后的激活值 - - access_history: 更新后的访问历史 - - last_access_time: 最后访问时间 - - access_count: 访问次数 - - importance_score: 重要性分数 + Dict[str, Any]: 更新后的节点数据 Raises: ValueError: 如果节点不存在或节点标签无效 - RuntimeError: 如果重试次数耗尽仍然失败 + RuntimeError: 如果更新失败 """ if current_time is None: current_time = datetime.now() @@ -121,62 +106,48 @@ class AccessHistoryManager: f"Invalid node_label: {node_label}. 
Must be one of {valid_labels}" ) - # 使用乐观锁重试机制处理并发冲突 - for attempt in range(self.max_retries): - try: - # 步骤1:读取当前节点状态 - node_data = await self._fetch_node(node_id, node_label, end_user_id) - - if not node_data: - raise ValueError( - f"Node not found: {node_label} with id={node_id}" - ) - - # 步骤2:计算新的访问历史和激活值 - update_data = await self._calculate_update( - node_data=node_data, - current_time=current_time, - current_time_iso=current_time_iso, - access_times=access_times + try: + # 步骤1:读取当前节点状态 + node_data = await self._fetch_node(node_id, node_label, end_user_id) + + if not node_data: + raise ValueError( + f"Node not found: {node_label} with id={node_id}" ) - - # 步骤3:原子性更新节点(使用事务) - updated_node = await self._atomic_update( - node_id=node_id, - node_label=node_label, - update_data=update_data, - end_user_id=end_user_id - ) - - logger.info( - f"成功记录访问: {node_label}[{node_id}], " - f"activation={update_data['activation_value']:.4f}, " - f"access_count={update_data['access_count']}" - f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}" - ) - - return updated_node - - except Exception as e: - if attempt < self.max_retries - 1: - # 带抖动的指数退避:base * 2^attempt * random(0.5, 1.0) - import random - backoff = 0.05 * (2 ** attempt) * random.uniform(0.5, 1.0) - logger.warning( - f"访问记录失败(尝试 {attempt + 1}/{self.max_retries})," - f"{backoff:.3f}s 后重试: {str(e)}" - ) - await asyncio.sleep(backoff) - continue - else: - logger.error( - f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], " - f"错误: {str(e)}" - ) - raise RuntimeError( - f"Failed to record access after {self.max_retries} attempts: {str(e)}" - ) - + + # 步骤2:计算新的访问历史和激活值 + update_data = await self._calculate_update( + node_data=node_data, + current_time=current_time, + current_time_iso=current_time_iso, + access_times=access_times + ) + + # 步骤3:使用 APOC 原子操作更新节点(无需重试) + updated_node = await self._atomic_update( + node_id=node_id, + node_label=node_label, + update_data=update_data, + end_user_id=end_user_id + ) + + logger.info( 
+ f"成功记录访问: {node_label}[{node_id}], " + f"activation={update_data['activation_value']:.4f}, " + f"access_count={update_data['access_count']}" + f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}" + ) + + return updated_node + + except Exception as e: + logger.error( + f"访问记录失败: {node_label}[{node_id}], 错误: {str(e)}" + ) + raise RuntimeError( + f"Failed to record access: {str(e)}" + ) from e + async def record_batch_access( self, node_ids: List[str], @@ -187,9 +158,7 @@ class AccessHistoryManager: """ 批量记录多个节点的访问 - 为提高性能,批量更新多个节点的访问历史。 - 对同一个节点的多次访问会先在内存中合并,只发起一次更新, - 从而避免同节点并发写入导致的乐观锁冲突。 + 对同一个节点的多次访问会先在内存中合并,只发起一次更新。 Args: node_ids: 节点ID列表(可包含重复ID) @@ -230,12 +199,10 @@ class AccessHistoryManager: ) tasks.append((node_id, task)) - # Execute all tasks in parallel task_results = await asyncio.gather( *[t for _, t in tasks], return_exceptions=True ) - # Collect successful results and count failures results = [] failed_count = 0 @@ -255,7 +222,7 @@ class AccessHistoryManager: ) return results - + async def check_consistency( self, node_id: str, @@ -264,22 +231,6 @@ class AccessHistoryManager: ) -> Tuple[ConsistencyCheckResult, Optional[str]]: """ 检查节点数据的一致性 - - 验证以下一致性规则: - 1. access_history[-1] == last_access_time - 2. len(access_history) == access_count - 3. 如果有访问历史,必须有激活值 - 4. 
激活值必须在有效范围内 [offset, 1.0] - - Args: - node_id: 节点ID - node_label: 节点标签 - end_user_id: 组ID(可选) - - Returns: - Tuple[ConsistencyCheckResult, Optional[str]]: - - 一致性检查结果枚举 - - 错误描述(如果不一致) """ node_data = await self._fetch_node(node_id, node_label, end_user_id) @@ -291,7 +242,6 @@ class AccessHistoryManager: access_count = node_data.get('access_count', 0) activation_value = node_data.get('activation_value') - # 检查1:access_history[-1] == last_access_time if access_history and last_access_time: if access_history[-1] != last_access_time: return ( @@ -300,7 +250,6 @@ class AccessHistoryManager: f"last_access_time={last_access_time}" ) - # 检查2:len(access_history) == access_count if len(access_history) != access_count: return ( ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT, @@ -308,14 +257,12 @@ class AccessHistoryManager: f"access_count={access_count}" ) - # 检查3:有访问历史必须有激活值 if access_history and activation_value is None: return ( ConsistencyCheckResult.MISSING_ACTIVATION, "Node has access_history but activation_value is None" ) - # 检查4:激活值范围 if activation_value is not None: offset = self.actr_calculator.offset if not (offset <= activation_value <= 1.0): @@ -326,30 +273,14 @@ class AccessHistoryManager: ) return ConsistencyCheckResult.CONSISTENT, None - + async def check_batch_consistency( self, node_label: str, end_user_id: Optional[str] = None, limit: int = 1000 ) -> Dict[str, Any]: - """ - 批量检查多个节点的一致性 - - Args: - node_label: 节点标签 - end_user_id: 组ID(可选) - limit: 检查的最大节点数 - - Returns: - Dict[str, Any]: 一致性检查报告,包含: - - total_checked: 检查的节点总数 - - consistent_count: 一致的节点数 - - inconsistent_count: 不一致的节点数 - - inconsistencies: 不一致节点的详细信息列表 - - consistency_rate: 一致性率(0-1) - """ - # 查询所有相关节点 + """批量检查多个节点的一致性""" query = f""" MATCH (n:{node_label}) WHERE n.access_history IS NOT NULL @@ -368,7 +299,6 @@ class AccessHistoryManager: results = await self.connector.execute_query(query, **params) node_ids = [r['id'] for r in results] - # 检查每个节点 inconsistencies = [] consistent_count = 
0 @@ -407,32 +337,15 @@ class AccessHistoryManager: ) return report - + async def repair_inconsistency( self, node_id: str, node_label: str, end_user_id: Optional[str] = None ) -> bool: - """ - 自动修复节点的数据不一致问题 - - 修复策略: - 1. 如果access_history[-1] != last_access_time:使用access_history[-1] - 2. 如果len(access_history) != access_count:使用len(access_history) - 3. 如果有历史但无激活值:重新计算激活值 - 4. 如果激活值超出范围:重新计算激活值 - - Args: - node_id: 节点ID - node_label: 节点标签 - end_user_id: 组ID(可选) - - Returns: - bool: 修复成功返回True,否则返回False - """ + """自动修复节点的数据不一致问题""" try: - # 检查一致性 result, message = await self.check_consistency( node_id=node_id, node_label=node_label, @@ -443,7 +356,6 @@ class AccessHistoryManager: logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]") return True - # 获取节点数据 node_data = await self._fetch_node(node_id, node_label, end_user_id) if not node_data: logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") @@ -452,17 +364,13 @@ class AccessHistoryManager: access_history = node_data.get('access_history') or [] importance_score = node_data.get('importance_score', 0.5) - # 准备修复数据 repair_data = {} - # 修复last_access_time if access_history: repair_data['last_access_time'] = access_history[-1] - # 修复access_count repair_data['access_count'] = len(access_history) - # 修复activation_value if access_history: current_time = datetime.now() last_access_dt = datetime.fromisoformat(access_history[-1]) @@ -478,7 +386,6 @@ class AccessHistoryManager: ) repair_data['activation_value'] = activation_value - # 执行修复 query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ @@ -509,26 +416,16 @@ class AccessHistoryManager: f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}" ) return False - + # ==================== 私有辅助方法 ==================== - + async def _fetch_node( self, node_id: str, node_label: str, end_user_id: Optional[str] = None ) -> Optional[Dict[str, Any]]: - """ - 获取节点数据 - - Args: - node_id: 节点ID - node_label: 节点标签 - end_user_id: 组ID(可选) - - Returns: - Optional[Dict[str, Any]]: 节点数据,如果不存在返回None - 
""" + """获取节点数据""" query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ @@ -552,7 +449,7 @@ class AccessHistoryManager: if results: return results[0] return None - + async def _calculate_update( self, node_data: Dict[str, Any], @@ -570,43 +467,37 @@ class AccessHistoryManager: access_times: 本次访问次数(合并后可能大于1) Returns: - Dict[str, Any]: 更新数据,包含所有需要更新的字段 + Dict[str, Any]: 更新数据 """ - access_history = node_data.get('access_history') or [] - # Handle None importance_score - default to 0.5 importance_score = node_data.get('importance_score') if importance_score is None: importance_score = 0.5 - # 追加新的访问时间(合并场景下追加多条相同时间戳) - new_access_history = access_history + [current_time_iso] * access_times + # 本次新增的时间戳 + new_timestamps = [current_time_iso] * access_times - # 修剪访问历史(如果过长) - access_history_dt = [ - datetime.fromisoformat(ts) for ts in new_access_history - ] + # 仅用本次新增的访问记录计算激活值 + new_history_dt = [current_time] * access_times trimmed_history_dt = self.actr_calculator.trim_access_history( - access_history=access_history_dt, + access_history=new_history_dt, current_time=current_time ) - trimmed_history = [ts.isoformat() for ts in trimmed_history_dt] - # 计算新的激活值 activation_value = self.actr_calculator.calculate_memory_activation( access_history=trimmed_history_dt, current_time=current_time, - last_access_time=current_time, # 最后访问时间就是当前时间 + last_access_time=current_time, importance_score=importance_score ) - # 返回所有需要更新的字段 return { 'activation_value': activation_value, - 'access_history': trimmed_history, + 'new_timestamps': new_timestamps, + 'access_count_delta': access_times, + 'access_count': len(trimmed_history_dt), 'last_access_time': current_time_iso, - 'access_count': len(trimmed_history) } - + async def _atomic_update( self, node_id: str, @@ -615,10 +506,10 @@ class AccessHistoryManager: end_user_id: Optional[str] = None ) -> Dict[str, Any]: """ - 原子性更新节点(使用乐观锁) + 原子性更新节点(使用 APOC 原子操作) - 使用Neo4j事务和版本号确保所有字段同时更新或回滚。 - 实现乐观锁机制防止并发冲突。 + 使用 apoc.atomic.add 和 
apoc.atomic.insert 保证并发安全, + 无需 version 字段和乐观锁,数据库层面保证原子性。 Args: node_id: 节点ID @@ -630,127 +521,68 @@ class AccessHistoryManager: Dict[str, Any]: 更新后的节点数据 Raises: - RuntimeError: 如果更新失败或发生版本冲突 + RuntimeError: 如果更新失败 """ - # 定义事务函数 - async def update_transaction(tx, node_id, node_label, update_data, end_user_id): - # 步骤1:读取当前节点并获取版本号 - read_query = f""" - MATCH (n:{node_label} {{id: $node_id}}) - """ - if end_user_id: - read_query += " WHERE n.end_user_id = $end_user_id" - read_query += """ - RETURN n.id as id, - n.version as version, - n.activation_value as activation_value, - n.access_history as access_history, - n.last_access_time as last_access_time, - n.access_count as access_count, - n.importance_score as importance_score - """ + content_field_map = { + 'Statement': 'n.statement as statement', + 'MemorySummary': 'n.content as content', + 'ExtractedEntity': 'null as content_placeholder', + 'Community': 'n.summary as summary' + } + + if node_label not in content_field_map: + raise ValueError( + f"Unsupported node_label: {node_label}. 
" + f"Supported labels are: {list(content_field_map.keys())}" + ) + + content_field = content_field_map[node_label] + + where_clause = "" + if end_user_id: + where_clause = " AND n.end_user_id = $end_user_id" + + query = f""" + MATCH (n:{node_label} {{id: $node_id}}) + WHERE true{where_clause} + CALL apoc.atomic.add(n, 'access_count', $access_count_delta, 5) YIELD oldValue AS old_count + WITH n + CALL (n) {{ + UNWIND $new_timestamps AS ts + CALL apoc.atomic.insert(n, 'access_history', size(n.access_history), ts, 5) YIELD oldValue + RETURN count(*) AS inserted + }} + SET n.activation_value = $activation_value, + n.last_access_time = $last_access_time + RETURN n.id as id, + n.activation_value as activation_value, + n.access_history as access_history, + n.last_access_time as last_access_time, + n.access_count as access_count, + n.importance_score as importance_score, + {content_field} + """ + + params = { + 'node_id': node_id, + 'access_count_delta': update_data['access_count_delta'], + 'new_timestamps': update_data['new_timestamps'], + 'activation_value': update_data['activation_value'], + 'last_access_time': update_data['last_access_time'], + } + if end_user_id: + params['end_user_id'] = end_user_id + + try: + results = await self.connector.execute_query(query, **params) - read_params = {'node_id': node_id} - if end_user_id: - read_params['end_user_id'] = end_user_id - - read_result = await tx.run(read_query, **read_params) - current_node = await read_result.single() - - if not current_node: + if not results: raise RuntimeError(f"Node not found: {node_label}[{node_id}]") - # 获取当前版本号(如果不存在则为0) - current_version = current_node.get('version', 0) or 0 - new_version = current_version + 1 - - # 步骤2:使用乐观锁更新节点 - # 根据节点类型构建完整的查询语句 - content_field_map = { - 'Statement': 'n.statement as statement', - 'MemorySummary': 'n.content as content', - 'ExtractedEntity': 'null as content_placeholder', # 占位符,后续会被过滤 - 'Community': 'n.summary as summary' - } - - # 显式检查节点类型,不支持的类型抛出错误 - if 
node_label not in content_field_map: - raise ValueError( - f"Unsupported node_label: {node_label}. " - f"Supported labels are: {list(content_field_map.keys())}" - ) - - content_field = content_field_map[node_label] - - # 构建 WHERE 子句 - where_conditions = [] - if end_user_id: - where_conditions.append("n.end_user_id = $end_user_id") - - # 添加版本检查 - if current_version > 0: - where_conditions.append("n.version = $current_version") - else: - where_conditions.append("(n.version IS NULL OR n.version = 0)") - - where_clause = " AND ".join(where_conditions) if where_conditions else "true" - - # 构建完整的更新查询 - update_query = f""" - MATCH (n:{node_label} {{id: $node_id}}) - WHERE {where_clause} - SET n.activation_value = $activation_value, - n.access_history = $access_history, - n.last_access_time = $last_access_time, - n.access_count = $access_count, - n.version = $new_version - RETURN n.id as id, - n.activation_value as activation_value, - n.access_history as access_history, - n.last_access_time as last_access_time, - n.access_count as access_count, - n.importance_score as importance_score, - n.version as version, - {content_field} - """ - - update_params = { - 'node_id': node_id, - 'current_version': current_version, - 'new_version': new_version, - 'activation_value': update_data['activation_value'], - 'access_history': update_data['access_history'], - 'last_access_time': update_data['last_access_time'], - 'access_count': update_data['access_count'] - } - if end_user_id: - update_params['end_user_id'] = end_user_id - - update_result = await tx.run(update_query, **update_params) - updated_node = await update_result.single() - - if not updated_node: - raise RuntimeError( - f"Version conflict detected for {node_label}[{node_id}]. " - f"Expected version {current_version}, but node was modified by another transaction." 
- ) - - # 转换为字典并移除占位符字段 - result_dict = dict(updated_node) + result_dict = dict(results[0]) result_dict.pop('content_placeholder', None) return result_dict - - # 执行事务 - try: - result = await self.connector.execute_write_transaction( - update_transaction, - node_id=node_id, - node_label=node_label, - update_data=update_data, - end_user_id=end_user_id - ) - return result except Exception as e: logger.error( f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}" From cd8229f3702f211b914506a715c87647887287e3 Mon Sep 17 00:00:00 2001 From: wxy Date: Tue, 7 Apr 2026 15:57:09 +0800 Subject: [PATCH 04/85] fix(workflow): restore opening statement and citation display in shared workflows --- api/app/services/app_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 36d7e614..5e26a629 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1360,6 +1360,7 @@ class AppService: variables=cfg.get("variables", []), execution_config=cfg.get("execution_config", {}), triggers=cfg.get("triggers", []), + features=cfg.get("features", {}), is_active=True, created_at=now, updated_at=now, From e48b146e60ff726972bbd087cd4fb9ed78b017b3 Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Tue, 7 Apr 2026 17:11:45 +0800 Subject: [PATCH 05/85] Revert "fix(workflow): restore opening statement and citation in shared conversations" --- api/app/services/app_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 5e26a629..36d7e614 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1360,7 +1360,6 @@ class AppService: variables=cfg.get("variables", []), execution_config=cfg.get("execution_config", {}), triggers=cfg.get("triggers", []), - features=cfg.get("features", {}), is_active=True, created_at=now, updated_at=now, From f2d7479229593ada1da568297f155282ed76c8c1 Mon Sep 17 
00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 9 Apr 2026 11:01:56 +0800 Subject: [PATCH 06/85] feat(memory): add async user metadata extraction pipeline - Add MetadataExtractor to collect user-related statements post-dedup and extract profile/behavioral metadata via independent LLM call - Add Celery task (extract_user_metadata) routed to memory_tasks queue - Add metadata models (UserMetadata, UserMetadataProfile, etc.) - Add metadata utility functions (clean, validate, merge with _op support) - Add Jinja2 prompt template for metadata extraction (zh/en) - Fix Lucene query parameter naming: rename `q` to `query` across all Cypher queries, graph_search functions, and callers - Escape `/` in Lucene queries to prevent TokenMgrError - Add `speaker` field to ChunkNode and persist it in Neo4j - Remove unused imports (argparse, os, UUID) in search.py - Fix unnecessary db context nesting in interest distribution task --- api/app/celery_app.py | 3 + .../nodes/perceptual_retrieve_node.py | 4 +- api/app/core/memory/models/__init__.py | 12 ++ api/app/core/memory/models/graph_models.py | 2 + api/app/core/memory/models/metadata_models.py | 40 ++++ api/app/core/memory/src/search.py | 12 +- .../extraction_orchestrator.py | 33 ++- .../metadata_extractor.py | 152 ++++++++++++++ .../triplet_extraction.py | 1 - .../storage_services/search/keyword_search.py | 4 +- api/app/core/memory/utils/data/text_utils.py | 4 +- api/app/core/memory/utils/metadata_utils.py | 179 +++++++++++++++++ .../prompt/prompts/extract_triplet.jinja2 | 8 + .../prompts/extract_user_metadata.jinja2 | 74 +++++++ api/app/repositories/neo4j/cypher_queries.py | 26 +-- api/app/repositories/neo4j/graph_search.py | 49 +++-- api/app/tasks.py | 189 +++++++++++++++--- 17 files changed, 714 insertions(+), 78 deletions(-) create mode 100644 api/app/core/memory/models/metadata_models.py create mode 100644 api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py create mode 
100644 api/app/core/memory/utils/metadata_utils.py create mode 100644 api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 23fd82ed..0f8a197c 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -111,6 +111,9 @@ celery_app.conf.update( # Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题) 'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'}, + # Metadata extraction → memory_tasks queue + 'app.tasks.extract_user_metadata': {'queue': 'memory_tasks'}, + # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py b/api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py index f248afa5..1cf5e291 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py @@ -153,7 +153,7 @@ class PerceptualSearchService: return [] try: r = await search_perceptual( - connector=connector, q=escaped, + connector=connector, query=escaped, end_user_id=self.end_user_id, limit=limit * 5, # 多查一些以提高命中率 ) @@ -178,7 +178,7 @@ class PerceptualSearchService: if not escaped.strip(): return [] r = await search_perceptual( - connector=connector, q=escaped, + connector=connector, query=escaped, end_user_id=self.end_user_id, limit=limit, ) return r.get("perceptuals", []) diff --git a/api/app/core/memory/models/__init__.py b/api/app/core/memory/models/__init__.py index 41d08908..eed8e8c4 100644 --- a/api/app/core/memory/models/__init__.py +++ b/api/app/core/memory/models/__init__.py @@ -58,6 +58,14 @@ from app.core.memory.models.triplet_models import ( TripletExtractionResponse, ) +# User metadata models +from 
app.core.memory.models.metadata_models import ( + UserMetadata, + UserMetadataBehavioralHints, + UserMetadataProfile, + MetadataExtractionResponse, +) + # Ontology scenario models (LLM extracted from scenarios) from app.core.memory.models.ontology_scenario_models import ( OntologyClass, @@ -124,6 +132,10 @@ __all__ = [ "Entity", "Triplet", "TripletExtractionResponse", + "UserMetadata", + "UserMetadataBehavioralHints", + "UserMetadataProfile", + "MetadataExtractionResponse", # Ontology models "OntologyClass", "OntologyExtractionResponse", diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 1b8c9d52..6e34421c 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -364,12 +364,14 @@ class ChunkNode(Node): Attributes: dialog_id: ID of the parent dialog content: The text content of the chunk + speaker: Speaker identifier ('user' or 'assistant') chunk_embedding: Optional embedding vector for the chunk sequence_number: Order of this chunk within the dialog metadata: Additional chunk metadata as key-value pairs """ dialog_id: str = Field(..., description="ID of the parent dialog") content: str = Field(..., description="The text content of the chunk") + speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") sequence_number: int = Field(..., description="Order of this chunk within the dialog") metadata: dict = Field(default_factory=dict, description="Additional chunk metadata") diff --git a/api/app/core/memory/models/metadata_models.py b/api/app/core/memory/models/metadata_models.py new file mode 100644 index 00000000..e3184879 --- /dev/null +++ b/api/app/core/memory/models/metadata_models.py @@ -0,0 +1,40 @@ +"""Models for user metadata extraction. 
+ +Independent from triplet_models.py - these models are used by the +standalone metadata extraction pipeline (post-dedup async Celery task). +""" + +from typing import List + +from pydantic import BaseModel, ConfigDict, Field + + +class UserMetadataProfile(BaseModel): + """用户画像信息""" + model_config = ConfigDict(extra='ignore') + role: str = Field(default="", description="用户职业或角色,如 teacher, doctor, software_engineer") + domain: str = Field(default="", description="用户所在领域,如 education, healthcare, software_development") + expertise: List[str] = Field(default_factory=list, description="用户擅长的技能或工具") + interests: List[str] = Field(default_factory=list, description="用户关注的话题或领域标签") + + +class UserMetadataBehavioralHints(BaseModel): + """行为偏好""" + model_config = ConfigDict(extra='ignore') + learning_stage: str = Field(default="", description="学习阶段") + preferred_depth: str = Field(default="", description="偏好深度") + tone_preference: str = Field(default="", description="语气偏好") + + +class UserMetadata(BaseModel): + """用户元数据顶层结构""" + model_config = ConfigDict(extra='ignore') + profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile) + behavioral_hints: UserMetadataBehavioralHints = Field(default_factory=UserMetadataBehavioralHints) + knowledge_tags: List[str] = Field(default_factory=list, description="知识标签") + + +class MetadataExtractionResponse(BaseModel): + """元数据提取 LLM 响应结构""" + model_config = ConfigDict(extra='ignore') + user_metadata: UserMetadata = Field(default_factory=UserMetadata) diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index ef39a12e..4e2883d5 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -1,4 +1,3 @@ -import argparse import asyncio import json import math @@ -6,7 +5,6 @@ import os import time from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional -from uuid import UUID if TYPE_CHECKING: from app.schemas.memory_config_schema import 
MemoryConfig @@ -23,7 +21,7 @@ from app.core.memory.utils.config.config_utils import ( ) from app.core.memory.utils.data.text_utils import extract_plain_query from app.core.memory.utils.data.time_utils import normalize_date_safe -from app.core.memory.utils.llm.llm_utils import get_reranker_client +# from app.core.memory.utils.llm.llm_utils import get_reranker_client from app.core.models.base import RedBearModelConfig from app.db import get_db_context from app.repositories.neo4j.graph_search import ( @@ -748,11 +746,10 @@ async def run_hybrid_search( if search_type in ["keyword", "hybrid"]: # Keyword-based search logger.info("[PERF] Starting keyword search...") - keyword_start = time.time() keyword_task = asyncio.create_task( search_graph( connector=connector, - q=query_text, + query=query_text, end_user_id=end_user_id, limit=limit, include=include @@ -762,7 +759,6 @@ async def run_hybrid_search( if search_type in ["embedding", "hybrid"]: # Embedding-based search logger.info("[PERF] Starting embedding search...") - embedding_start = time.time() # 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig config_load_start = time.time() @@ -904,10 +900,10 @@ async def run_hybrid_search( else: results["latency_metrics"] = latency_metrics - logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====") + logger.info("[PERF] ===== SEARCH PERFORMANCE SUMMARY =====") logger.info(f"[PERF] Total search completed in {total_latency:.4f}s") logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}") - logger.info(f"[PERF] =========================================") + logger.info("[PERF] =========================================") # Sanitize results: drop large/unused fields _remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 3229674d..8f6d9853 100644 
--- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -311,8 +311,35 @@ class ExtractionOrchestrator: dialog_data_list, ) - # 步骤 7: 同步用户别名到数据库表(仅正式模式) + # 步骤 7: 同步用户别名到数据库表 + 触发异步元数据提取(仅正式模式) if not is_pilot_run: + # 收集用户相关 statement 并触发异步元数据提取 + try: + from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import MetadataExtractor + metadata_extractor = MetadataExtractor(llm_client=self.llm_client, language=self.language) + user_statements = metadata_extractor.collect_user_related_statements( + entity_nodes, statement_nodes, + statement_entity_edges + ) + if user_statements: + # 获取 end_user_id 和 config_id + end_user_id = dialog_data_list[0].end_user_id if dialog_data_list else None + config_id = dialog_data_list[0].config_id if dialog_data_list and hasattr(dialog_data_list[0], 'config_id') else None + if end_user_id: + from app.tasks import extract_user_metadata_task + extract_user_metadata_task.delay( + end_user_id=str(end_user_id), + statements=user_statements, + config_id=str(config_id) if config_id else None, + language=self.language, + ) + logger.info(f"已触发异步元数据提取任务,共 {len(user_statements)} 条用户相关 statement") + else: + logger.info("未找到用户相关 statement,跳过元数据提取") + except Exception as e: + logger.error(f"触发元数据提取任务失败(不影响主流程): {e}", exc_info=True) + + # 同步用户别名到数据库表 logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表") await self._update_end_user_other_name(entity_nodes, dialog_data_list) @@ -1107,6 +1134,7 @@ class ExtractionOrchestrator: end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id content=chunk.content, + speaker=getattr(chunk, 'speaker', None), chunk_embedding=chunk.chunk_embedding, sequence_number=chunk_idx, # 添加必需的 sequence_number 字段 created_at=dialog_data.created_at, @@ -1342,7 +1370,7 @@ class ExtractionOrchestrator: async def 
_update_end_user_other_name( self, entity_nodes: List[ExtractedEntityNode], - dialog_data_list: List[DialogData] + dialog_data_list: List[DialogData], ) -> None: """ 将本轮提取的用户别名同步到 end_user 和 end_user_info 表。 @@ -1470,7 +1498,6 @@ class ExtractionOrchestrator: end_user_id=end_user_uuid, other_name=first_alias, aliases=merged_aliases, - meta_data={} )) logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={merged_aliases}") diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py new file mode 100644 index 00000000..5e763622 --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py @@ -0,0 +1,152 @@ +""" +Metadata extractor module. + +Collects user-related statements from post-dedup graph data and +extracts user metadata via an independent LLM call. +""" + +import logging +from typing import List, Optional + +from app.core.memory.models.graph_models import ( + ExtractedEntityNode, + StatementEntityEdge, + StatementNode, +) +from app.core.memory.models.metadata_models import ( + MetadataExtractionResponse, + UserMetadata, +) + +logger = logging.getLogger(__name__) + +# Reuse the same user-entity detection logic from dedup module +_USER_NAMES = {"用户", "我", "user", "i"} +_CANONICAL_USER_TYPE = "用户" + + +def _is_user_entity(ent: ExtractedEntityNode) -> bool: + """判断实体是否为用户实体""" + name = (getattr(ent, "name", "") or "").strip().lower() + etype = (getattr(ent, "entity_type", "") or "").strip() + return name in _USER_NAMES or etype == _CANONICAL_USER_TYPE + + +class MetadataExtractor: + """Extracts user metadata from post-dedup graph data via independent LLM call.""" + + def __init__(self, llm_client, language: str = "zh"): + self.llm_client = llm_client + self.language = language + + @staticmethod + def detect_language(statements: List[str]) -> str: 
+ """根据 statement 文本内容检测语言。 + 如果文本中包含中文字符则返回 "zh",否则返回 "en"。 + """ + import re + combined = " ".join(statements) + if re.search(r'[\u4e00-\u9fff]', combined): + return "zh" + return "en" + + def collect_user_related_statements( + self, + entity_nodes: List[ExtractedEntityNode], + statement_nodes: List[StatementNode], + statement_entity_edges: List[StatementEntityEdge], + ) -> List[str]: + """ + 从去重后的数据中筛选与用户直接相关且由用户发言的 statement 文本。 + + 筛选逻辑: + 1. 用户实体 → StatementEntityEdge → statement(直接关联) + 2. 只保留 speaker="user" 的 statement(过滤 assistant 回复的噪声) + + Returns: + 用户发言的 statement 文本列表 + """ + # Find user entity IDs + user_entity_ids = set() + for ent in entity_nodes: + if _is_user_entity(ent): + user_entity_ids.add(ent.id) + + if not user_entity_ids: + logger.debug("未找到用户实体节点,跳过 statement 收集") + return [] + + # 用户实体 → StatementEntityEdge → statement + target_stmt_ids = set() + for edge in statement_entity_edges: + if edge.target in user_entity_ids: + target_stmt_ids.add(edge.source) + + # Collect: only speaker="user" statements, preserving order + result = [] + seen = set() + total_associated = 0 + skipped_non_user = 0 + for stmt_node in statement_nodes: + if stmt_node.id in target_stmt_ids and stmt_node.id not in seen: + total_associated += 1 + speaker = getattr(stmt_node, 'speaker', None) or 'unknown' + if speaker == "user": + text = (stmt_node.statement or "").strip() + if text: + result.append(text) + else: + skipped_non_user += 1 + seen.add(stmt_node.id) + + logger.info( + f"收集到 {len(result)} 条用户发言 statement " + f"(直接关联: {total_associated}, speaker=user: {len(result)}, " + f"跳过非user: {skipped_non_user})" + ) + if total_associated > 0 and len(result) == 0: + logger.warning( + f"有 {total_associated} 条直接关联 statement 但全部被 speaker 过滤," + f"可能本次写入不包含 user 消息" + ) + return result + + async def extract_metadata(self, statements: List[str]) -> Optional[UserMetadata]: + """ + 对筛选后的 statement 列表调用 LLM 提取元数据。 + 语言根据 statement 内容自动检测,不依赖系统界面语言。 + + Returns: + UserMetadata on 
success, None on failure + """ + if not statements: + return None + + try: + from app.core.memory.utils.prompt.prompt_utils import prompt_env + + # 根据写入内容的语言自动检测,而非使用系统界面语言 + detected_language = self.detect_language(statements) + logger.info(f"元数据提取语言检测结果: {detected_language}") + + template = prompt_env.get_template("extract_user_metadata.jinja2") + prompt = template.render( + statements=statements, + language=detected_language, + json_schema="", + ) + + response = await self.llm_client.response_structured( + messages=[{"role": "user", "content": prompt}], + response_model=MetadataExtractionResponse, + ) + + if response and response.user_metadata: + return response.user_metadata + + logger.warning("LLM 返回的元数据为空") + return None + + except Exception as e: + logger.error(f"元数据提取 LLM 调用失败: {e}", exc_info=True) + return None diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py index 7fb74b82..ea355ca1 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py @@ -1,4 +1,3 @@ -import os import asyncio from typing import List, Dict, Optional diff --git a/api/app/core/memory/storage_services/search/keyword_search.py b/api/app/core/memory/storage_services/search/keyword_search.py index d2591945..2458cf30 100644 --- a/api/app/core/memory/storage_services/search/keyword_search.py +++ b/api/app/core/memory/storage_services/search/keyword_search.py @@ -5,7 +5,7 @@ 使用Neo4j的全文索引进行高效的文本匹配。 """ -from typing import List, Dict, Any, Optional +from typing import List, Optional from app.core.logging_config import get_memory_logger from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.storage_services.search.search_strategy import SearchStrategy, 
SearchResult @@ -74,7 +74,7 @@ class KeywordSearchStrategy(SearchStrategy): # 调用底层的关键词搜索函数 results_dict = await search_graph( connector=self.connector, - q=query_text, + query=query_text, end_user_id=end_user_id, limit=limit, include=include_list diff --git a/api/app/core/memory/utils/data/text_utils.py b/api/app/core/memory/utils/data/text_utils.py index d0b10f97..eaed0940 100644 --- a/api/app/core/memory/utils/data/text_utils.py +++ b/api/app/core/memory/utils/data/text_utils.py @@ -22,7 +22,9 @@ def escape_lucene_query(query: str) -> str: s = s.replace("\r", " ").replace("\n", " ").strip() # Lucene reserved tokens/special characters - specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':'] + # NOTE: '/' is the regex delimiter in Lucene — must be escaped to prevent + # TokenMgrError when the query contains unmatched slashes. + specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/'] # Replace longer tokens first to avoid partial double-escaping for token in sorted(specials, key=len, reverse=True): s = s.replace(token, f"\\{token}") diff --git a/api/app/core/memory/utils/metadata_utils.py b/api/app/core/memory/utils/metadata_utils.py new file mode 100644 index 00000000..ccdd1686 --- /dev/null +++ b/api/app/core/memory/utils/metadata_utils.py @@ -0,0 +1,179 @@ +""" +Metadata utility functions for cleaning, validating, aggregating, and merging +user metadata extracted from conversations. +""" + +import logging +from datetime import datetime, timezone +from typing import Optional + +from app.core.memory.models.metadata_models import UserMetadata + +logger = logging.getLogger(__name__) + + +def clean_metadata(raw: dict) -> dict: + """ + Clean metadata by removing empty string values and empty array fields recursively. + Only keeps fields with actual content. If a nested dict becomes empty after cleaning, + it is removed too. 
+ """ + cleaned = {} + for key, value in raw.items(): + if isinstance(value, dict): + nested = clean_metadata(value) + if nested: + cleaned[key] = nested + elif isinstance(value, list): + if len(value) > 0: + cleaned[key] = value + elif isinstance(value, str): + if value != "": + cleaned[key] = value + else: + cleaned[key] = value + return cleaned + +# TODO 这个函数没有调用的地方 +def validate_metadata(raw: dict) -> Optional[UserMetadata]: + """ + Validate metadata structure using the Pydantic UserMetadata model. + Returns None and logs a WARNING on validation failure. + """ + try: + return UserMetadata.model_validate(raw) + except Exception as e: + logger.warning("Metadata validation failed: %s", e) + return None + + +def merge_metadata(existing: dict, new: dict) -> dict: + """ + Merge new extracted metadata with existing database metadata. + - Scalar fields: new value overwrites old value + - Array fields: support _op marker (append/replace/remove) + - Missing top-level keys in new: preserve existing data + - Auto-update _updated_at timestamp dict with field paths and ISO timestamps + - When existing is None or {}: directly write new + _updated_at (no merge logic) + """ + now = datetime.now(timezone.utc).isoformat() + + if not existing: + # Direct write: new + _updated_at for all fields + result = dict(new) + updated_at = {} + _collect_field_paths(result, "", updated_at, now) + if updated_at: + result["_updated_at"] = updated_at + return result + + result = dict(existing) + updated_at: dict = dict(result.get("_updated_at", {})) + + for key, new_value in new.items(): + if key == "_updated_at": + continue + + old_value = result.get(key) + + if isinstance(new_value, dict) and isinstance(old_value, dict): + # Nested dict merge (e.g. 
profile, behavioral_hints) + _merge_nested(result, key, old_value, new_value, updated_at, now) + elif isinstance(new_value, list) or (isinstance(new_value, dict) and "_op" in new_value): + # Array field with possible _op + _merge_array_field(result, key, old_value, new_value, updated_at, now) + else: + # Scalar top-level field + if old_value != new_value: + result[key] = new_value + updated_at[key] = now + # If equal, no change needed + + result["_updated_at"] = updated_at + return result + +# TODO 考虑大函数包含小函数,因为只服务于大函数,实现代码文件的结构清楚 +def _collect_field_paths(data: dict, prefix: str, updated_at: dict, now: str) -> None: + """Collect all leaf field paths for _updated_at on direct write.""" + for key, value in data.items(): + if key == "_updated_at": + continue + path = f"{prefix}{key}" if not prefix else f"{prefix}.{key}" + if isinstance(value, dict): + _collect_field_paths(value, path, updated_at, now) + else: + updated_at[path] = now + + +def _merge_nested( + result: dict, key: str, old_dict: dict, new_dict: dict, + updated_at: dict, now: str +) -> None: + """Merge a nested dict (e.g. 
profile, behavioral_hints).""" + merged = dict(old_dict) + for field, new_val in new_dict.items(): + old_val = merged.get(field) + path = f"{key}.{field}" + + if isinstance(new_val, list) or (isinstance(new_val, dict) and "_op" in new_val): + _merge_array_field_inner(merged, field, old_val, new_val, updated_at, path, now) + else: + # Scalar field + if old_val != new_val: + merged[field] = new_val + updated_at[path] = now + result[key] = merged + + +def _merge_array_field( + result: dict, key: str, old_value, new_value, + updated_at: dict, now: str +) -> None: + """Merge a top-level array field with _op support.""" + _merge_array_field_inner(result, key, old_value, new_value, updated_at, key, now) + + +def _merge_array_field_inner( + container: dict, field: str, old_value, new_value, + updated_at: dict, path: str, now: str +) -> None: + """Core array merge logic with _op support.""" + # Determine op and items + if isinstance(new_value, dict) and "_op" in new_value: + op = new_value.get("_op", "append") + items = new_value.get(field, new_value.get("items", [])) + # If the dict has a key matching the field name, use it; otherwise look for list values + if not isinstance(items, list): + # Try to find the list value in the dict (excluding _op) + for k, v in new_value.items(): + if k != "_op" and isinstance(v, list): + items = v + break + else: + items = [] + elif isinstance(new_value, list): + op = "append" + items = new_value + else: + op = "append" + items = [] + + old_arr = old_value if isinstance(old_value, list) else [] + + if op == "replace": + new_arr = items + elif op == "remove": + new_arr = [x for x in old_arr if x not in items] + else: + # append (default): merge and deduplicate + seen = list(old_arr) + for item in items: + if item not in seen: + seen.append(item) + new_arr = seen + + if old_arr != new_arr: + container[field] = new_arr + updated_at[path] = now + else: + container[field] = new_arr diff --git 
a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index 7ded48a4..1a79b482 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -406,4 +406,12 @@ Output: - **⚠️ ALIASES ORDER: preserve temporal order of appearance** - **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []** +**Output JSON structure:** +```json +{ + "triplets": [...], + "entities": [...] +} +``` + {{ json_schema }} diff --git a/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 new file mode 100644 index 00000000..9053e57d --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 @@ -0,0 +1,74 @@ +===Task=== +Extract user metadata from the following conversation statements spoken by the user. + +{% if language == "zh" %} +**"三度原则"判断标准:** +- 复用度:该信息是否会被多个功能模块使用? +- 约束度:该信息是否会影响系统行为? +- 时效性:该信息是长期稳定的还是临时的?仅提取长期稳定信息。 + +**提取规则:** +- **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息 +- 仅提取文本中明确提到的信息,不要推测 +- 如果文本中没有可提取的用户画像信息,返回空的 user_metadata 对象 +- **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值) + +**字段说明:** +- profile.role:用户的职业或角色,如 教师、医生、后端工程师 +- profile.domain:用户所在领域,如 教育、医疗、软件开发 +- profile.expertise:用户擅长的技能或工具(通用,不限于编程),如 Python、心理咨询、高中物理 +- profile.interests:用户主动表达兴趣的话题或领域标签 +- behavioral_hints.learning_stage:学习阶段(初学者/中级/高级) +- behavioral_hints.preferred_depth:偏好深度(概览/技术细节/深入探讨) +- behavioral_hints.tone_preference:语气偏好(轻松随意/专业简洁/学术严谨) +- knowledge_tags:用户涉及的知识领域标签 +{% else %} +**"Three-Degree Principle" criteria:** +- Reusability: Will this information be used by multiple functional modules? +- Constraint: Will this information affect system behavior? +- Timeliness: Is this information long-term stable or temporary? Only extract long-term stable information. 
+ +**Extraction rules:** +- **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user +- Only extract information explicitly mentioned in the text, do not speculate +- If no user profile information can be extracted, return an empty user_metadata object +- **Output language must match the input text language** + +**Field descriptions:** +- profile.role: User's occupation or role, e.g. teacher, doctor, software engineer +- profile.domain: User's domain, e.g. education, healthcare, software development +- profile.expertise: User's skills or tools (general, not limited to programming) +- profile.interests: Topics or domain tags the user actively expressed interest in +- behavioral_hints.learning_stage: Learning stage (beginner/intermediate/advanced) +- behavioral_hints.preferred_depth: Preferred depth (overview/detailed/deep dive) +- behavioral_hints.tone_preference: Tone preference (casual/professional/academic) +- knowledge_tags: Knowledge domain tags related to the user +{% endif %} + +===User Statements=== +{% for stmt in statements %} +- {{ stmt }} +{% endfor %} + +===Output Format=== +Return a JSON object with the following structure: +```json +{ + "user_metadata": { + "profile": { + "role": "", + "domain": "", + "expertise": [], + "interests": [] + }, + "behavioral_hints": { + "learning_stage": "", + "preferred_depth": "", + "tone_preference": "" + }, + "knowledge_tags": [] + } +} +``` + +{{ json_schema }} diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index aa246829..4b5273ac 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -23,6 +23,7 @@ SET s += { end_user_id: statement.end_user_id, stmt_type: statement.stmt_type, statement: statement.statement, + speaker: statement.speaker, emotion_intensity: statement.emotion_intensity, emotion_target: 
statement.emotion_target, emotion_subject: statement.emotion_subject, @@ -56,6 +57,7 @@ SET c += { expired_at: chunk.expired_at, dialog_id: chunk.dialog_id, content: chunk.content, + speaker: chunk.speaker, chunk_embedding: chunk.chunk_embedding, sequence_number: chunk.sequence_number, start_index: chunk.start_index, @@ -283,7 +285,7 @@ LIMIT $limit """ SEARCH_STATEMENTS_BY_KEYWORD = """ -CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score +CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) @@ -307,7 +309,7 @@ LIMIT $limit """ # 查询实体名称包含指定字符串的实体 SEARCH_ENTITIES_BY_NAME = """ -CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score +CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) @@ -337,21 +339,21 @@ LIMIT $limit """ SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """ -CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score +CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) WITH e, score -WITH collect({entity: e, score: score}) AS fulltextResults +With collect({entity: e, score: score}) AS fulltextResults OPTIONAL MATCH (ae:ExtractedEntity) WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id) AND ae.aliases IS NOT NULL - AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q)) + AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($query)) WITH fulltextResults, collect(ae) AS aliasEntities UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score: CASE - WHEN ANY(alias IN 
x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0 - WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9 + WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($query)) THEN 1.0 + WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($query)) THEN 0.9 ELSE 0.8 END }]) AS row @@ -384,7 +386,7 @@ LIMIT $limit SEARCH_CHUNKS_BY_CONTENT = """ -CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score +CALL db.index.fulltext.queryNodes("chunksFulltext", $query) YIELD node AS c, score WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) @@ -501,7 +503,7 @@ LIMIT $limit """ SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """ -CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score +CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date))) AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date)))) @@ -677,7 +679,7 @@ SET n.invalid_at = $new_invalid_at # MemorySummary keyword search using fulltext index SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """ -CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score +CALL db.index.fulltext.queryNodes("summariesFulltext", $query) YIELD node AS m, score WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id) OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement) RETURN m.id AS id, @@ -1363,7 +1365,7 @@ RETURN c.community_id AS community_id # Community keyword search: matches name or summary via fulltext index SEARCH_COMMUNITIES_BY_KEYWORD = """ -CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score +CALL 
db.index.fulltext.queryNodes("communitiesFulltext", $query) YIELD node AS c, score WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) RETURN c.community_id AS id, c.name AS name, @@ -1451,7 +1453,7 @@ RETURN elementId(r) AS uuid """ SEARCH_PERCEPTUAL_BY_KEYWORD = """ -CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score +CALL db.index.fulltext.queryNodes("perceptualFulltext", $query) YIELD node AS p, score WHERE p.end_user_id = $end_user_id RETURN p.id AS id, p.end_user_id AS end_user_id, diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 32ec4474..a191dad6 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -2,6 +2,7 @@ import asyncio import logging from typing import Any, Dict, List, Optional +from app.core.memory.utils.data.text_utils import escape_lucene_query from app.repositories.neo4j.cypher_queries import ( CHUNK_EMBEDDING_SEARCH, COMMUNITY_EMBEDDING_SEARCH, @@ -87,7 +88,7 @@ async def _update_activation_values_batch( unique_node_ids.append(node_id) if not unique_node_ids: - logger.warning(f"批量更新激活值:没有有效的节点ID") + logger.warning("批量更新激活值:没有有效的节点ID") return nodes # 记录去重信息(仅针对具有有效 ID 的节点) @@ -223,7 +224,7 @@ async def _update_search_results_activation( async def search_graph( connector: Neo4jConnector, - q: str, + query: str, end_user_id: Optional[str] = None, limit: int = 50, include: List[str] = None, @@ -234,14 +235,14 @@ async def search_graph( OPTIMIZED: Runs all queries in parallel using asyncio.gather() INTEGRATED: Updates activation values for knowledge nodes before returning results - - Statements: matches s.statement CONTAINS q - - Entities: matches e.name CONTAINS q - - Chunks: matches s.content CONTAINS q (from Statement nodes) - - Summaries: matches ms.content CONTAINS q + - Statements: matches s.statement CONTAINS query + - Entities: matches e.name CONTAINS query + - Chunks: matches s.content 
CONTAINS query (from Statement nodes) + - Summaries: matches ms.content CONTAINS query Args: connector: Neo4j connector - q: Query text + query: Query text for full-text search end_user_id: Optional group filter limit: Max results per category include: List of categories to search (default: all) @@ -252,6 +253,9 @@ async def search_graph( if include is None: include = ["statements", "chunks", "entities", "summaries"] + # Escape Lucene special characters to prevent query parse errors + escaped_query = escape_lucene_query(query) + # Prepare tasks for parallel execution tasks = [] task_keys = [] @@ -260,7 +264,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD, json_format=True, - q=q, + query=escaped_query, end_user_id=end_user_id, limit=limit, )) @@ -270,7 +274,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_ENTITIES_BY_NAME_OR_ALIAS, json_format=True, - q=q, + query=escaped_query, end_user_id=end_user_id, limit=limit, )) @@ -280,7 +284,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_CHUNKS_BY_CONTENT, json_format=True, - q=q, + query=escaped_query, end_user_id=end_user_id, limit=limit, )) @@ -290,7 +294,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, json_format=True, - q=q, + query=escaped_query, end_user_id=end_user_id, limit=limit, )) @@ -300,7 +304,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_COMMUNITIES_BY_KEYWORD, json_format=True, - q=q, + query=escaped_query, end_user_id=end_user_id, limit=limit, )) @@ -482,7 +486,7 @@ async def search_graph_by_embedding( update_time = time.time() - update_start logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s") else: - logger.info(f"[PERF] Skipping activation updates (only summaries)") + logger.info("[PERF] Skipping activation updates (only summaries)") return results @@ -520,7 +524,7 @@ async def 
get_dedup_candidates_for_entities( # 适配新版查询:使用全 # 全文索引按名称检索(包含 CONTAINS 语义) rows = await connector.execute_query( SEARCH_ENTITIES_BY_NAME, - q=name, + query=escape_lucene_query(name), end_user_id=end_user_id, limit=100, ) @@ -544,7 +548,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 try: rows = await connector.execute_query( SEARCH_ENTITIES_BY_NAME, - q=name.lower(), + query=escape_lucene_query(name.lower()), end_user_id=end_user_id, limit=100, ) @@ -593,11 +597,12 @@ async def search_graph_by_keyword_temporal( - Returns up to 'limit' statements """ if not query_text: - logger.warning(f"query_text不能为空") + logger.warning("query_text不能为空") return {"statements": []} + escaped_query = escape_lucene_query(query_text) statements = await connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, - q=query_text, + query=escaped_query, end_user_id=end_user_id, start_date=start_date, end_date=end_date, @@ -671,7 +676,7 @@ async def search_graph_by_dialog_id( - Returns up to 'limit' dialogues """ if not dialog_id: - logger.warning(f"dialog_id不能为空") + logger.warning("dialog_id不能为空") return {"dialogues": []} dialogues = await connector.execute_query( @@ -690,7 +695,7 @@ async def search_graph_by_chunk_id( limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: if not chunk_id: - logger.warning(f"chunk_id不能为空") + logger.warning("chunk_id不能为空") return {"chunks": []} chunks = await connector.execute_query( SEARCH_CHUNK_BY_CHUNK_ID, @@ -968,7 +973,7 @@ async def search_graph_l_valid_at( async def search_perceptual( connector: Neo4jConnector, - q: str, + query: str, end_user_id: Optional[str] = None, limit: int = 10, ) -> Dict[str, List[Dict[str, Any]]]: @@ -979,7 +984,7 @@ async def search_perceptual( Args: connector: Neo4j connector - q: Query text + query: Query text for full-text search end_user_id: Optional user filter limit: Max results @@ -989,7 +994,7 @@ async def search_perceptual( try: perceptuals = await connector.execute_query( 
SEARCH_PERCEPTUAL_BY_KEYWORD, - q=q, + query=escape_lucene_query(query), end_user_id=end_user_id, limit=limit, ) diff --git a/api/app/tasks.py b/api/app/tasks.py index f918743c..4914e142 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1001,7 +1001,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): except Exception as e: print(f"\n\nError during fetch feishu: {e}") case _: # General - print(f"General: No synchronization needed\n") + print("General: No synchronization needed\n") result = f"sync knowledge '{db_knowledge.name}' processed successfully." return result @@ -1510,6 +1510,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: "status": "SUCCESS", "total_num": total_num, "end_user_count": len(end_users), + "end_user_details": end_user_details, "memory_increment_id": str(memory_increment.id), "created_at": memory_increment.created_at.isoformat(), }) @@ -2602,35 +2603,34 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ service = MemoryAgentService() - with get_db_context() as db: - for end_user_id in end_user_ids: - # 存在性检查:缓存有数据则跳过 - cached = await InterestMemoryCache.get_interest_distribution( + for end_user_id in end_user_ids: + # 存在性检查:缓存有数据则跳过 + cached = await InterestMemoryCache.get_interest_distribution( + end_user_id=end_user_id, + language=language, + ) + if cached is not None: + skipped += 1 + continue + + logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成") + try: + result = await service.get_interest_distribution_by_user( end_user_id=end_user_id, + limit=5, language=language, ) - if cached is not None: - skipped += 1 - continue - - logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成") - try: - result = await service.get_interest_distribution_by_user( - end_user_id=end_user_id, - limit=5, - language=language, - ) - await InterestMemoryCache.set_interest_distribution( - end_user_id=end_user_id, - language=language, - data=result, - expire=INTEREST_CACHE_EXPIRE, - ) - initialized += 1 - logger.info(f"用户 
{end_user_id} 兴趣分布缓存生成成功") - except Exception as e: - failed += 1 - logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}") + await InterestMemoryCache.set_interest_distribution( + end_user_id=end_user_id, + language=language, + data=result, + expire=INTEREST_CACHE_EXPIRE, + ) + initialized += 1 + logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功") + except Exception as e: + failed += 1 + logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}") logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}") return { @@ -2914,4 +2914,139 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace } +# ─── User Metadata Extraction Task ─────────────────────────────────────────── + +@celery_app.task( + bind=True, + name='app.tasks.extract_user_metadata', + ignore_result=False, + max_retries=0, + acks_late=True, + time_limit=300, + soft_time_limit=240, +) +def extract_user_metadata_task( + self, + end_user_id: str, + statements: List[str], + config_id: Optional[str] = None, + language: str = "zh", +) -> Dict[str, Any]: + """异步提取用户元数据并写入数据库。 + + 在去重消歧完成后由编排器触发,使用独立 LLM 调用提取元数据。 + LLM 配置优先使用 config_id 对应的应用配置,失败时回退到工作空间默认配置。 + + Args: + end_user_id: 终端用户 ID + statements: 用户相关的 statement 文本列表 + config_id: 应用配置 ID(可选) + language: 语言类型 ("zh" 中文, "en" 英文) + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + logger.info( + f"[CELERY METADATA] Starting metadata extraction - end_user_id={end_user_id}, " + f"statements_count={len(statements)}, config_id={config_id}, language={language}" + ) + + async def _run() -> Dict[str, Any]: + from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import MetadataExtractor + from app.core.memory.utils.metadata_utils import clean_metadata, merge_metadata, validate_metadata + from app.repositories.end_user_info_repository import EndUserInfoRepository + from app.repositories.end_user_repository import EndUserRepository + from app.services.memory_config_service import 
MemoryConfigService + + # 1. 获取 LLM 配置(应用配置 → 工作空间配置兜底)并创建 LLM client + with get_db_context() as db: + end_user_uuid = uuid.UUID(end_user_id) + + # 获取 workspace_id from end_user + end_user = EndUserRepository(db).get_by_id(end_user_uuid) + if not end_user: + return {"status": "FAILURE", "error": f"End user not found: {end_user_id}"} + + workspace_id = end_user.workspace_id + + config_service = MemoryConfigService(db) + memory_config = config_service.get_config_with_fallback( + memory_config_id=uuid.UUID(config_id) if config_id else None, + workspace_id=workspace_id, + ) + if not memory_config: + return {"status": "FAILURE", "error": "No LLM config available (app + workspace fallback failed)"} + + # 2. 创建 LLM client + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + factory = MemoryClientFactory(db) + if not memory_config.llm_id: + return {"status": "FAILURE", "error": "Memory config has no LLM model configured"} + llm_client = factory.get_llm_client(memory_config.llm_id) + + # 3. 提取元数据 + extractor = MetadataExtractor(llm_client=llm_client, language=language) + user_metadata = await extractor.extract_metadata(statements) + + if not user_metadata: + logger.info(f"[CELERY METADATA] No metadata extracted for end_user_id={end_user_id}") + return {"status": "SUCCESS", "result": "no_metadata_extracted"} + + # 4. 
清洗、校验、合并、写入 + raw_dict = user_metadata.model_dump() + cleaned = clean_metadata(raw_dict) + if not cleaned: + logger.info(f"[CELERY METADATA] Cleaned metadata is empty for end_user_id={end_user_id}") + return {"status": "SUCCESS", "result": "empty_after_cleaning"} + + validated = validate_metadata(cleaned) + if not validated: + return {"status": "FAILURE", "error": "Metadata validation failed after cleaning"} + + with get_db_context() as db: + end_user_uuid = uuid.UUID(end_user_id) + info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) + + if info: + existing_meta = info.meta_data if info.meta_data else {} + info.meta_data = merge_metadata(existing_meta, cleaned) + logger.info(f"[CELERY METADATA] Updated metadata for end_user_id={end_user_id}") + else: + # No end_user_info record yet - metadata will be written when alias sync creates it, + # or we create a minimal record here + logger.info( + f"[CELERY METADATA] No end_user_info record for end_user_id={end_user_id}, " + f"skipping metadata write (will be created by alias sync)" + ) + return {"status": "SUCCESS", "result": "no_info_record"} + + db.commit() + + return {"status": "SUCCESS", "result": "metadata_written"} + + loop = None + try: + loop = set_asyncio_event_loop() + result = loop.run_until_complete(_run()) + elapsed = time.time() - start_time + result["elapsed_time"] = elapsed + result["task_id"] = self.request.id + logger.info(f"[CELERY METADATA] Task completed - elapsed={elapsed:.2f}s, result={result.get('result')}") + return result + + except Exception as e: + elapsed = time.time() - start_time + logger.error(f"[CELERY METADATA] Task failed - elapsed={elapsed:.2f}s, error={e}", exc_info=True) + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": elapsed, + "task_id": self.request.id, + } + finally: + if loop: + _shutdown_loop_gracefully(loop) + + # unused task \ No newline at end of file From e0546e01ef3422afb88e6ee8c34db06948dd983a Mon Sep 17 00:00:00 2001 From: lanceyq 
<1982376970@qq.com> Date: Thu, 9 Apr 2026 15:10:29 +0800 Subject: [PATCH 07/85] refactor(memory): delegate metadata merging to LLM instead of code-based merge - Remove merge_metadata and its helper functions from metadata_utils.py - Pass existing_metadata to MetadataExtractor.extract_metadata() as LLM context - Add merge instructions to extract_user_metadata.jinja2 prompt (zh/en) - Update Celery task to read existing metadata before extraction and overwrite - Simplify field descriptions in UserMetadataProfile model - Add _update_timestamps helper to track changed fields --- api/app/core/memory/models/metadata_models.py | 4 +- .../metadata_extractor.py | 8 +- api/app/core/memory/utils/metadata_utils.py | 138 +----------------- .../prompts/extract_user_metadata.jinja2 | 27 ++++ api/app/tasks.py | 58 ++++++-- 5 files changed, 87 insertions(+), 148 deletions(-) diff --git a/api/app/core/memory/models/metadata_models.py b/api/app/core/memory/models/metadata_models.py index e3184879..a5c70ec6 100644 --- a/api/app/core/memory/models/metadata_models.py +++ b/api/app/core/memory/models/metadata_models.py @@ -12,8 +12,8 @@ from pydantic import BaseModel, ConfigDict, Field class UserMetadataProfile(BaseModel): """用户画像信息""" model_config = ConfigDict(extra='ignore') - role: str = Field(default="", description="用户职业或角色,如 teacher, doctor, software_engineer") - domain: str = Field(default="", description="用户所在领域,如 education, healthcare, software_development") + role: str = Field(default="", description="用户职业或角色") + domain: str = Field(default="", description="用户所在领域") expertise: List[str] = Field(default_factory=list, description="用户擅长的技能或工具") interests: List[str] = Field(default_factory=list, description="用户关注的话题或领域标签") diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py index 5e763622..af3331b9 100644 --- 
a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py @@ -111,10 +111,15 @@ class MetadataExtractor: ) return result - async def extract_metadata(self, statements: List[str]) -> Optional[UserMetadata]: + async def extract_metadata(self, statements: List[str], existing_metadata: Optional[dict] = None) -> Optional[UserMetadata]: """ 对筛选后的 statement 列表调用 LLM 提取元数据。 语言根据 statement 内容自动检测,不依赖系统界面语言。 + 传入已有元数据作为上下文,让 LLM 能判断 replace/remove 操作。 + + Args: + statements: 用户发言的 statement 文本列表 + existing_metadata: 数据库已有的元数据(可选),用于 LLM 对比判断变更 Returns: UserMetadata on success, None on failure @@ -133,6 +138,7 @@ class MetadataExtractor: prompt = template.render( statements=statements, language=detected_language, + existing_metadata=existing_metadata, json_schema="", ) diff --git a/api/app/core/memory/utils/metadata_utils.py b/api/app/core/memory/utils/metadata_utils.py index ccdd1686..69bd8edf 100644 --- a/api/app/core/memory/utils/metadata_utils.py +++ b/api/app/core/memory/utils/metadata_utils.py @@ -1,10 +1,8 @@ """ -Metadata utility functions for cleaning, validating, aggregating, and merging -user metadata extracted from conversations. +Metadata utility functions for cleaning and validating user metadata. """ import logging -from datetime import datetime, timezone from typing import Optional from app.core.memory.models.metadata_models import UserMetadata @@ -34,7 +32,7 @@ def clean_metadata(raw: dict) -> dict: cleaned[key] = value return cleaned -# TODO 这个函数没有调用的地方 + def validate_metadata(raw: dict) -> Optional[UserMetadata]: """ Validate metadata structure using the Pydantic UserMetadata model. 
@@ -45,135 +43,3 @@ def validate_metadata(raw: dict) -> Optional[UserMetadata]: except Exception as e: logger.warning("Metadata validation failed: %s", e) return None - - -def merge_metadata(existing: dict, new: dict) -> dict: - """ - Merge new extracted metadata with existing database metadata. - - Scalar fields: new value overwrites old value - - Array fields: support _op marker (append/replace/remove) - - Missing top-level keys in new: preserve existing data - - Auto-update _updated_at timestamp dict with field paths and ISO timestamps - - When existing is None or {}: directly write new + _updated_at (no merge logic) - """ - now = datetime.now(timezone.utc).isoformat() - - if not existing: - # Direct write: new + _updated_at for all fields - result = dict(new) - updated_at = {} - _collect_field_paths(result, "", updated_at, now) - if updated_at: - result["_updated_at"] = updated_at - return result - - result = dict(existing) - updated_at: dict = dict(result.get("_updated_at", {})) - - for key, new_value in new.items(): - if key == "_updated_at": - continue - - old_value = result.get(key) - - if isinstance(new_value, dict) and isinstance(old_value, dict): - # Nested dict merge (e.g. 
profile, behavioral_hints) - _merge_nested(result, key, old_value, new_value, updated_at, now) - elif isinstance(new_value, list) or (isinstance(new_value, dict) and "_op" in new_value): - # Array field with possible _op - _merge_array_field(result, key, old_value, new_value, updated_at, now) - else: - # Scalar top-level field - if old_value != new_value: - result[key] = new_value - updated_at[key] = now - # If equal, no change needed - - result["_updated_at"] = updated_at - return result - -# TODO 考虑大函数包含小函数,因为只服务于大函数,实现代码文件的结构清楚 -def _collect_field_paths(data: dict, prefix: str, updated_at: dict, now: str) -> None: - """Collect all leaf field paths for _updated_at on direct write.""" - for key, value in data.items(): - if key == "_updated_at": - continue - path = f"{prefix}{key}" if not prefix else f"{prefix}.{key}" - if isinstance(value, dict): - _collect_field_paths(value, path, updated_at, now) - else: - updated_at[path] = now - - -def _merge_nested( - result: dict, key: str, old_dict: dict, new_dict: dict, - updated_at: dict, now: str -) -> None: - """Merge a nested dict (e.g. 
profile, behavioral_hints).""" - merged = dict(old_dict) - for field, new_val in new_dict.items(): - old_val = merged.get(field) - path = f"{key}.{field}" - - if isinstance(new_val, list) or (isinstance(new_val, dict) and "_op" in new_val): - _merge_array_field_inner(merged, field, old_val, new_val, updated_at, path, now) - else: - # Scalar field - if old_val != new_val: - merged[field] = new_val - updated_at[path] = now - result[key] = merged - - -def _merge_array_field( - result: dict, key: str, old_value, new_value, - updated_at: dict, now: str -) -> None: - """Merge a top-level array field with _op support.""" - _merge_array_field_inner(result, key, old_value, new_value, updated_at, key, now) - - -def _merge_array_field_inner( - container: dict, field: str, old_value, new_value, - updated_at: dict, path: str, now: str -) -> None: - """Core array merge logic with _op support.""" - # Determine op and items - if isinstance(new_value, dict) and "_op" in new_value: - op = new_value.get("_op", "append") - items = new_value.get(field, new_value.get("items", [])) - # If the dict has a key matching the field name, use it; otherwise look for list values - if not isinstance(items, list): - # Try to find the list value in the dict (excluding _op) - for k, v in new_value.items(): - if k != "_op" and isinstance(v, list): - items = v - break - else: - items = [] - elif isinstance(new_value, list): - op = "append" - items = new_value - else: - op = "append" - items = [] - - old_arr = old_value if isinstance(old_value, list) else [] - - if op == "replace": - new_arr = items - elif op == "remove": - new_arr = [x for x in old_arr if x not in items] - else: - # append (default): merge and deduplicate - seen = list(old_arr) - for item in items: - if item not in seen: - seen.append(item) - new_arr = seen - - if old_arr != new_arr: - container[field] = new_arr - updated_at[path] = now - else: - container[field] = new_arr diff --git 
a/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 index 9053e57d..c280e5f6 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 @@ -13,6 +13,16 @@ Extract user metadata from the following conversation statements spoken by the u - 如果文本中没有可提取的用户画像信息,返回空的 user_metadata 对象 - **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值) +{% if existing_metadata %} +**重要:合并已有元数据** +下方提供了数据库中已有的用户元数据。请结合用户最新发言,输出**合并后的完整元数据**: +- 如果用户明确否定了已有信息(如"我不再教高中物理了"),在输出中**移除**该信息 +- 如果用户提到了新信息,**添加**到对应字段中 +- 如果已有信息未被用户否定,**保留**在输出中 +- 标量字段(如 role、domain):如果用户提到了新值,用新值替换;否则保留已有值 +- 最终输出应该是完整的、合并后的元数据,不是增量 +{% endif %} + **字段说明:** - profile.role:用户的职业或角色,如 教师、医生、后端工程师 - profile.domain:用户所在领域,如 教育、医疗、软件开发 @@ -34,6 +44,16 @@ Extract user metadata from the following conversation statements spoken by the u - If no user profile information can be extracted, return an empty user_metadata object - **Output language must match the input text language** +{% if existing_metadata %} +**Important: Merge with existing metadata** +Existing user metadata from the database is provided below. Combine with the user's latest statements to output the **complete merged metadata**: +- If the user explicitly negates existing info (e.g. "I no longer teach high school physics"), **remove** it from output +- If the user mentions new info, **add** it to the corresponding field +- If existing info is not negated by the user, **keep** it in the output +- Scalar fields (e.g. role, domain): replace with new value if user mentions one; otherwise keep existing +- The final output should be the complete, merged metadata — not an incremental update +{% endif %} + **Field descriptions:** - profile.role: User's occupation or role, e.g. teacher, doctor, software engineer - profile.domain: User's domain, e.g. 
education, healthcare, software development @@ -50,6 +70,13 @@ Extract user metadata from the following conversation statements spoken by the u - {{ stmt }} {% endfor %} +{% if existing_metadata %} +===Existing User Metadata=== +```json +{{ existing_metadata | tojson }} +``` +{% endif %} + ===Output Format=== Return a JSON object with the following structure: ```json diff --git a/api/app/tasks.py b/api/app/tasks.py index 4914e142..3eb1a52c 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,4 +1,5 @@ import asyncio +import json import os import re import shutil @@ -2916,6 +2917,20 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace # ─── User Metadata Extraction Task ─────────────────────────────────────────── + +def _update_timestamps(existing: dict, new: dict, updated_at: dict, now: str, prefix: str = "") -> None: + """对比新旧元数据,更新变更字段的 _updated_at 时间戳。""" + for key, new_val in new.items(): + if key == "_updated_at": + continue + path = f"{prefix}.{key}" if prefix else key + old_val = existing.get(key) + + if isinstance(new_val, dict) and isinstance(old_val, dict): + _update_timestamps(old_val, new_val, updated_at, now, prefix=path) + elif old_val != new_val: + updated_at[path] = now + @celery_app.task( bind=True, name='app.tasks.extract_user_metadata', @@ -2954,7 +2969,7 @@ def extract_user_metadata_task( async def _run() -> Dict[str, Any]: from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import MetadataExtractor - from app.core.memory.utils.metadata_utils import clean_metadata, merge_metadata, validate_metadata + from app.core.memory.utils.metadata_utils import clean_metadata, validate_metadata from app.repositories.end_user_info_repository import EndUserInfoRepository from app.repositories.end_user_repository import EndUserRepository from app.services.memory_config_service import MemoryConfigService @@ -2985,36 +3000,61 @@ def extract_user_metadata_task( return {"status": 
"FAILURE", "error": "Memory config has no LLM model configured"} llm_client = factory.get_llm_client(memory_config.llm_id) - # 3. 提取元数据 + # 2.5 读取已有元数据,传给 extractor 作为上下文 + existing_metadata = None + try: + info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) + if info and info.meta_data: + existing_metadata = info.meta_data + logger.info("[CELERY METADATA] 已读取数据库已有元数据作为 LLM 上下文") + except Exception as e: + logger.warning(f"[CELERY METADATA] 读取已有元数据失败(继续无上下文提取): {e}") + + # 3. 提取元数据(传入已有元数据作为上下文) extractor = MetadataExtractor(llm_client=llm_client, language=language) - user_metadata = await extractor.extract_metadata(statements) + user_metadata = await extractor.extract_metadata(statements, existing_metadata=existing_metadata) if not user_metadata: logger.info(f"[CELERY METADATA] No metadata extracted for end_user_id={end_user_id}") return {"status": "SUCCESS", "result": "no_metadata_extracted"} - # 4. 清洗、校验、合并、写入 - raw_dict = user_metadata.model_dump() + # 4. 清洗、校验、覆盖写入 + raw_dict = user_metadata.model_dump(exclude_none=True) + logger.info(f"[CELERY METADATA] LLM 输出完整元数据: {json.dumps(raw_dict, ensure_ascii=False)}") + cleaned = clean_metadata(raw_dict) if not cleaned: logger.info(f"[CELERY METADATA] Cleaned metadata is empty for end_user_id={end_user_id}") return {"status": "SUCCESS", "result": "empty_after_cleaning"} + logger.info(f"[CELERY METADATA] 清洗后元数据: {json.dumps(cleaned, ensure_ascii=False)}") + validated = validate_metadata(cleaned) if not validated: return {"status": "FAILURE", "error": "Metadata validation failed after cleaning"} + # 直接覆盖写入(LLM 已完成语义合并,输出的是完整结果) + # 保留 _updated_at 时间戳追踪 + from datetime import datetime as dt, timezone as tz + now = dt.now(tz.utc).isoformat() + with get_db_context() as db: end_user_uuid = uuid.UUID(end_user_id) info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) if info: existing_meta = info.meta_data if info.meta_data else {} - info.meta_data = merge_metadata(existing_meta, cleaned) - 
logger.info(f"[CELERY METADATA] Updated metadata for end_user_id={end_user_id}") + logger.info(f"[CELERY METADATA] 数据库已有元数据: {json.dumps(existing_meta, ensure_ascii=False)}") + + # 保留已有的 _updated_at,更新变更字段的时间戳 + updated_at = dict(existing_meta.get("_updated_at", {})) + _update_timestamps(existing_meta, cleaned, updated_at, now) + + final = dict(cleaned) + final["_updated_at"] = updated_at + info.meta_data = final + logger.info(f"[CELERY METADATA] 覆盖写入元数据: {json.dumps(final, ensure_ascii=False)}") else: - # No end_user_info record yet - metadata will be written when alias sync creates it, - # or we create a minimal record here logger.info( f"[CELERY METADATA] No end_user_info record for end_user_id={end_user_id}, " f"skipping metadata write (will be created by alias sync)" From 9f9ac69f9768505c26ec1da6eefc7e3639545804 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 9 Apr 2026 17:38:35 +0800 Subject: [PATCH 08/85] feat(web): add OpenClawTool --- web/src/i18n/en.ts | 13 ++++++++ web/src/i18n/zh.ts | 15 ++++++++++ web/src/views/ToolManagement/Inner.tsx | 6 ++-- web/src/views/ToolManagement/constant.ts | 38 ++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 3 deletions(-) diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index bf1a3ce8..d54f0d25 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -2113,6 +2113,19 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re return_text_position_enable: 'Return Text Position Info', return_text_position_enable_desc: 'Whether to return coordinate positions of recognized text', + OpenClawTool_desc: 'OpenClaw Remote Agent', + OpenClawTool_features: 'OpenClaw Remote Agent — 3D Printing and Device Management', + OpenClawTool_config_desc: 'Configure OpenClaw Gateway connection. Server URL and API Key are required.', + OpenClawTool_server_url_desc: 'OpenClaw Gateway server URL, e.g. 
http://xxx.xxx.xxx.xx:xxx', + OpenClawTool_api_key_desc: 'OpenClaw API Key, created in OpenClaw admin console', + OpenClawTool_agent_id_desc: 'Target Agent ID, defaults to main, usually no need to change', + OpenClawTool_enable: 'Enable OpenClaw', + agent_id: 'Agent ID', + '3dPrinting': '3D Printing', + deviceManagement: 'Device Management', + multimodalInteraction: 'Multimodal Interaction', + remoteAgent: 'Remote Agent', + addCustom: 'Add Custom Tool', editCustom: 'Edit Custom Tool', schema: 'Schema', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index a774ec55..5b46cb48 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -2109,6 +2109,21 @@ export const zh = { return_text_position_enable: '返回文本位置信息', return_text_position_enable_desc: '是否返回识别文字的坐标位置', + OpenClawTool_desc: 'OpenClaw远程Agent', + OpenClawTool_features: 'OpenClaw远程Agent —3D打印控制、设备管理等', + OpenClawTool_config_desc: '配置OpenClaw Gateway连接信息,需要提供服务地址和API Key。', + OpenClawTool_server_url_desc: 'OpenClaw Gateway 服务地址,如 http://xxx.xxx.xxx.xx:xxx', + OpenClawTool_api_key_desc: 'OpenClaw API Key,在 OpenClaw 管理后台创建', + OpenClawTool_agent_id_desc: '目标 Agent ID,默认为 main,通常无需修改', + OpenClawTool_enable: '启用 OpenClaw', + agent_id: 'Agent ID', + '3dPrinting': '3D 打印', + deviceManagement: '设备管理', + multimodalInteraction: '多模态交互', + remoteAgent: '远程 Agent', + + + addCustom: '添加自定义工具', editCustom: '编辑自定义工具', schema: 'Schema', diff --git a/web/src/views/ToolManagement/Inner.tsx b/web/src/views/ToolManagement/Inner.tsx index b88428b0..67c3a6f5 100644 --- a/web/src/views/ToolManagement/Inner.tsx +++ b/web/src/views/ToolManagement/Inner.tsx @@ -101,13 +101,13 @@ const Inner: React.FC<{ getStatusTag: (status: string) => ReactNode; keyword?: s {InnerConfigData[item.config_data.tool_class].features?.slice(0, 2).map((type, i) => ( -
{type}
+
{t(`tool.${type}`)}
))}
{InnerConfigData[item.config_data.tool_class].features.length > 2 && ( {InnerConfigData[item.config_data.tool_class].features?.slice(2, InnerConfigData[item.config_data.tool_class].features.length).map((type, i) => ( -
{type}
+
{t(`tool.${type}`)}
))}
} color="white" placement="bottom" @@ -135,7 +135,7 @@ const Inner: React.FC<{ getStatusTag: (status: string) => ReactNode; keyword?: s {InnerConfigData[item.config_data.tool_class].eg} : -
{t('configStatus')}
+
{t('tool.configStatus')}
{t(`tool.${item.status}_desc`)} } diff --git a/web/src/views/ToolManagement/constant.ts b/web/src/views/ToolManagement/constant.ts index 6763a140..5641ed4d 100644 --- a/web/src/views/ToolManagement/constant.ts +++ b/web/src/views/ToolManagement/constant.ts @@ -186,5 +186,43 @@ export const InnerConfigData: Record = { 'multilingualSupport', 'highPrecisionRecognition' ], + }, + OpenClawTool: { + link: 'https://openclaw.ai/', + config: { + server_url: { + name: ['config', 'parameters', 'server_url'], + type: 'input', + desc: 'OpenClawTool_server_url_desc', + rules: [ + { required: true, message: 'common.pleaseEnter' } + ] + }, + api_key: { + name: ['config', 'parameters', 'api_key'], + type: 'input', + desc: 'OpenClawTool_api_key_desc', + rules: [ + { required: true, message: 'common.pleaseEnter' } + ] + }, + agent_id: { + name: ['config', 'parameters', 'agent_id'], + type: 'input', + desc: 'OpenClawTool_agent_id_desc', + defaultValue: 'main', + }, + OpenClawTool_enable: { + name: ['config', 'is_enabled'], + type: 'checkbox', + defaultValue: true, + }, + }, + features: [ + '3dPrinting', + 'deviceManagement', + 'multimodalInteraction', + 'remoteAgent' + ], } } \ No newline at end of file From 33a1c178ff0225af2050999a9a7b09f0c7286b37 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 9 Apr 2026 17:45:42 +0800 Subject: [PATCH 09/85] fix(web): if-else node case show --- .../Workflow/components/Nodes/ConditionNode.tsx | 4 ++-- web/src/views/Workflow/hooks/useWorkflowGraph.ts | 12 +++++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/web/src/views/Workflow/components/Nodes/ConditionNode.tsx b/web/src/views/Workflow/components/Nodes/ConditionNode.tsx index 79e8352c..996ae5dd 100644 --- a/web/src/views/Workflow/components/Nodes/ConditionNode.tsx +++ b/web/src/views/Workflow/components/Nodes/ConditionNode.tsx @@ -14,7 +14,7 @@ const caculateIsSet = (item: any, type: string) => { case 'cases': { if (!item.left) return false if (['not_empty', 
'empty'].includes(item.operator)) return true - return !!item.left && (!!item.right || typeof item.right === 'boolean') + return !!item.left && (!!item.right || typeof item.right === 'boolean' || typeof item.right === 'number') } } } @@ -22,7 +22,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => { const data = node?.getData() || {}; const { t } = useTranslation() const graphRef = useRef(node?.model?.graph) - const variableList = useVariableList(node ?? null, graphRef, []) + const variableList = useVariableList(node ?? null, graphRef, data.chatVariables ?? []) const getLocaleField = (field: string, filedType: string) => { const key = filedType === 'boolean' ? `workflow.config.if-else..boolean.${field}` : filedType === 'number' ? `workflow.config.if-else.num.${field}` : `workflow.config.if-else.${field}` diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index 45400362..19e8662b 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -106,6 +106,16 @@ export const useWorkflowGraph = ({ const [chatVariables, setChatVariables] = useState([]) const featuresRef = useRef(undefined) + useEffect(() => { + if (!graphRef.current) return + graphRef.current.getNodes().forEach(node => { + const data = node.getData() + if (data?.type === 'if-else' || data?.type === 'question-classifier') { + node.setData({ ...data, chatVariables }, { silent: true }) + } + }) + }, [chatVariables]) + useEffect(() => { getConfig() }, [id]) @@ -211,7 +221,7 @@ export const useWorkflowGraph = ({ id, type, name, - data: { ...node, ...nodeLibraryConfig}, + data: { ...node, ...nodeLibraryConfig, ...((type === 'if-else' || type === 'question-classifier') ? 
{ chatVariables } : {}) }, ...position, } From e298b38de95552dfe8175176ff345083b0600521 Mon Sep 17 00:00:00 2001 From: miao <1468212639@qq.com> Date: Tue, 7 Apr 2026 10:35:14 +0800 Subject: [PATCH 10/85] feat(tools): add OpenClaw remote agent tool integration - Detect x-openclaw flag in OpenAPI schema and init dedicated config - Implement multimodal input/output (image download, compress, base64) - Add OpenClaw connection test and status validation in tool service - Fix auth_config token check to support both api_key and bearer_token - Inject runtime context (user_id, conversation_id, files) in chat services --- api/app/core/tools/custom/base.py | 349 +++++++++++++++++++++++++- api/app/services/app_chat_service.py | 24 +- api/app/services/draft_run_service.py | 26 +- api/app/services/tool_service.py | 74 ++++++ 4 files changed, 467 insertions(+), 6 deletions(-) diff --git a/api/app/core/tools/custom/base.py b/api/app/core/tools/custom/base.py index 3dfe4c93..3f3daad7 100644 --- a/api/app/core/tools/custom/base.py +++ b/api/app/core/tools/custom/base.py @@ -30,9 +30,60 @@ class CustomTool(BaseTool): self.auth_config = config.get("auth_config", {}) self.base_url = config.get("base_url", "") self.timeout = config.get("timeout", 30) - - # 解析schema - self._parsed_operations = self._parse_openapi_schema() + + #===========OpenClaw特殊判断(取到OpenClaw特殊配置)========== + schema = self.schema_content + if isinstance(schema, str): + try: + schema = json.loads(schema) + self.schema_content = schema + except json.JSONDecodeError: + schema = {} + + info = schema.get("info", {}) if isinstance(schema, dict) else {} + self._is_openclaw = info.get("x-openclaw", False) + + if self._is_openclaw: + # 从扩展字段读取 OpenClaw 配置 + self._openclaw_agent_id = info.get("x-openclaw-agent-id", "main") + self._openclaw_model = info.get("x-openclaw-default-model", "openclaw") + self._openclaw_session_strategy = info.get( + "x-openclaw-session-strategy", "by_user") + self._openclaw_timeout = 
info.get("x-openclaw-timeout", 60) + self._openclaw_input_mode = info.get("x-openclaw-input-mode", "text") + self._openclaw_output_mode = info.get("x-openclaw-output-mode", "text") + + # 从 servers 读取 base_url + servers = schema.get("servers", []) + if servers: + self.base_url = servers[0].get("url", "") + + # 从 auth_config 读取 token(兼容 api_key 和 bearer_token 两种认证方式) + self._openclaw_token = ( + self.auth_config.get("api_key") # api_key 认证方式 + or self.auth_config.get("token") # bearer_token 认证方式 + or "" + ) + + # 覆盖 timeout + self.timeout = self._openclaw_timeout + + # 运行时上下文(后续注入) + self._user_id = "anonymous" + self._conversation_id = None + self._uploaded_files = [] # 新增:用户上传的文件 + + # 跳过 Schema 解析 + self._parsed_operations = {} + + logger.info( + f"检测到 OpenClaw 工具: agent_id={self._openclaw_agent_id}, " + f"base_url={self.base_url}, " + f"input_mode={self._openclaw_input_mode}, " + f"output_mode={self._openclaw_output_mode}") + else: + # 解析schema + self._parsed_operations = self._parse_openapi_schema() @property def name(self) -> str: @@ -58,6 +109,31 @@ class CustomTool(BaseTool): @property def parameters(self) -> List[ToolParameter]: """工具参数定义""" + # ========== OpenClaw 特判 根据输入模式解析是否需要image_url ========== + if self._is_openclaw: + params = [ + ToolParameter( + name="message", + type=ParameterType.STRING, + description="发送给 OpenClaw Agent 的文本请求内容", + required=True + ) + ] + # 多模态输入模式下,增加 image_url 参数 + if self._openclaw_input_mode == "multimodal": + params.append(ToolParameter( + name="image_url", + type=ParameterType.STRING, + description=( + "可选,附带的图片URL或base64 data URI" + "(如 data:image/png;base64,...)。" + "传入后 Agent 可以理解图片内容。" + ), + required=False + )) + return params + # ========== 特判结束 ========== + params = [] # 添加操作选择参数 @@ -90,6 +166,10 @@ class CustomTool(BaseTool): async def execute(self, **kwargs) -> ToolResult: """执行自定义工具""" + # ========== OpenClaw 特判 ========== + if self._is_openclaw: + return await self._execute_openclaw(**kwargs) + # ========== 特判结束 
========== start_time = time.time() try: @@ -130,6 +210,269 @@ class CustomTool(BaseTool): execution_time=execution_time ) + #=============openclaw执行函数开始=============== + async def _execute_openclaw(self, **kwargs) -> ToolResult: + """OpenClaw 专属执行逻辑(支持多模态输入)""" + start_time = time.time() + try: + message = kwargs.get("message", "") + # 从用户实际上传的文件中提取图片 URL + image_url = None + if self._uploaded_files: + for f in self._uploaded_files: + if f.get("type") == "image": + source = f.get("source", {}) + if source.get("type") == "base64": + media_type = source.get("media_type", "image/jpeg") + data = source.get("data", "") + image_url = f"data:{media_type};base64,{data}" + elif f.get("image"): + # DashScope 格式:{"type": "image", "image": "url"} + image_url = f.get("image") + elif f.get("url"): + # 其他格式:{"type": "image", "url": "https://..."} + image_url = f.get("url") + break # 只取第一张图片 + + # 如果 image_url 是服务器中转 URL,直接下载图片转 base64 + # 避免 OSS 签名 URL 在重定向解析过程中被破坏 + if image_url and not image_url.startswith("data:"): + try: + import base64 + from io import BytesIO + from PIL import Image + + MAX_RAW_SIZE = 4 * 1024 * 1024 # 超过 4MB 则压缩 + + async with aiohttp.ClientSession() as _session: + async with _session.get(image_url, allow_redirects=True, timeout=aiohttp.ClientTimeout(total=30)) as _resp: + if _resp.status == 200: + content_type = _resp.headers.get("Content-Type", "image/jpeg") + if content_type.startswith("image/"): + img_bytes = await _resp.read() + original_size = len(img_bytes) + logger.info(f"OpenClaw 下载图片: size={original_size} bytes, type={content_type}") + + if original_size > MAX_RAW_SIZE: + img = Image.open(BytesIO(img_bytes)) + if img.mode in ("RGBA", "P", "LA"): + img = img.convert("RGB") + max_side = 2048 + if max(img.size) > max_side: + img.thumbnail((max_side, max_side), Image.LANCZOS) + buf = BytesIO() + img.save(buf, format="JPEG", quality=75, optimize=True) + img_bytes = buf.getvalue() + content_type = "image/jpeg" + logger.info(f"OpenClaw 图片已压缩: 
{original_size} -> {len(img_bytes)} bytes") + + b64_data = base64.b64encode(img_bytes).decode("utf-8") + image_url = f"data:{content_type};base64,{b64_data}" + logger.info(f"OpenClaw 图片已转为 base64, size={len(img_bytes)} bytes") + else: + logger.warning(f"OpenClaw 图片 URL 返回非图片类型: {content_type}") + else: + logger.warning(f"OpenClaw 下载图片失败: HTTP {_resp.status}") + except Exception as e: + logger.warning(f"OpenClaw 下载图片失败,使用原始 URL: {e}") + + + if not message: + return ToolResult.error_result( + error="message 参数不能为空", + error_code="OPENCLAW_INVALID_INPUT", + execution_time=time.time() - start_time) + + url = f"{self.base_url.rstrip('/')}/v1/responses" + #请求头 + headers = { + "Authorization": f"Bearer {self._openclaw_token}", + "Content-Type": "application/json", + "x-openclaw-agent-id": self._openclaw_agent_id + } + + # session 路由 + if (self._openclaw_session_strategy == "by_conversation" + and self._conversation_id): + user_field = f"conv-{self._conversation_id}" + else: + user_field = f"user-{self._user_id}" + + # 根据 input_mode 和是否有图片构造 input + input_field = self._build_openclaw_input(message, image_url) + #请求体 + body = { + "model": self._openclaw_model, + "user": user_field, + "input": input_field, + "stream": False + } + + logger.info(f"OpenClaw 请求体: {json.dumps(body, ensure_ascii=False)[:1000]}") + + timeout_config = aiohttp.ClientTimeout(total=self.timeout) + #请求 + async with aiohttp.ClientSession(timeout=timeout_config) as session: + async with session.post(url, json=body, headers=headers) as resp: + execution_time = time.time() - start_time + + if resp.status >= 400: + error_text = await resp.text() + _img_preview2 = (image_url[:100] + "...") if image_url and len(image_url) > 100 else image_url + logger.error( + f"OpenClaw 调用失败: HTTP {resp.status}, " + f"url={url}, agent_id={self._openclaw_agent_id}, " + f"has_image={bool(image_url)}, image_url={_img_preview2}, " + f"input_type={'multimodal' if isinstance(input_field, list) else 'text'}, " + 
f"error_response={error_text[:1000]}" + ) + return ToolResult.error_result( + error=f"OpenClaw HTTP {resp.status}: {error_text[:500]}", + error_code="OPENCLAW_HTTP_ERROR", + execution_time=execution_time) + + data = await resp.json() + + # 根据 output_mode 解析响应 + result = self._extract_openclaw_response( + data, self._openclaw_output_mode) + display_text = self._format_openclaw_result(result) + + logger.info( + "OpenClaw 调用成功", + extra={ + "tool_id": self.tool_id, + "agent_id": self._openclaw_agent_id, + "has_images": len(result["images"]) > 0, + "execution_time": execution_time + }) + return ToolResult.success_result( + data=display_text, execution_time=execution_time) + + except aiohttp.ClientError as e: + return ToolResult.error_result( + error=f"OpenClaw 网络连接失败: {str(e)}", + error_code="OPENCLAW_NETWORK_ERROR", + execution_time=time.time() - start_time) + except Exception as e: + return ToolResult.error_result( + error=f"OpenClaw 调用失败: {str(e)}", + error_code="OPENCLAW_EXECUTION_ERROR", + execution_time=time.time() - start_time) + + def _build_openclaw_input(self, message: str, image_url: str = None): + """根据 input_mode 和是否有图片构造 OpenClaw input 字段 + + 纯文本模式或无图片 → 返回字符串 + 多模态模式且有图片 → 返回结构化 item 数组 + """ + if not image_url or self._openclaw_input_mode != "multimodal": + return message + + # 构造多模态 content 数组 + content_parts = [ + {"type": "input_text", "text": message} + ] + + if image_url.startswith("data:"): + # base64 data URI: data:image/png;base64,iVBORw0KGgo... 
+ try: + header, data = image_url.split(",", 1) + media_type = header.split(":")[1].split(";")[0] + content_parts.append({ + "type": "input_image", + "source": { + "type": "base64", + "media_type": media_type, + "data": data + } + }) + except (ValueError, IndexError): + logger.warning("无法解析 base64 data URI,回退为纯文本输入") + return message + else: + # URL 引用 + content_parts.append({ + "type": "input_image", + "source": { + "type": "url", + "url": image_url + } + }) + + return [{ + "type": "message", + "role": "user", + "content": content_parts + }] + + @staticmethod + def _extract_openclaw_response(response_data: Dict[str, Any], + output_mode: str = "text") -> Dict[str, Any]: + """从 OpenClaw 响应中提取文本和图片 + + 响应格式: + {"output": [{"type": "message", "content": [ + {"type": "output_text", "text": "..."}, + {"type": "output_image", "image_url": "..."} + ]}]} + + 返回: + {"text": "文本内容", "images": [{"url": "...", "media_type": "image/png"}]} + """ + output = response_data.get("output", []) + texts = [] + images = [] + + for item in output: + if item.get("type") == "message": + for content in item.get("content", []): + content_type = content.get("type") + + if content_type == "output_text": + text = content.get("text", "") + if text: + texts.append(text) + + elif content_type == "output_image" and output_mode == "multimodal": + image_url = content.get("image_url", "") + if image_url: + images.append({ + "url": image_url, + "media_type": content.get("media_type", "image/png") + }) + + text_result = "\n".join(texts) if texts else "" + + # text 模式下只返回文本(向后兼容) + if output_mode == "text": + return {"text": text_result or str(response_data), "images": []} + + return {"text": text_result, "images": images} + + @staticmethod + def _format_openclaw_result(result: Dict[str, Any]) -> str: + """将解析结果格式化为返回给 LLM 的字符串 + + 纯文本 → 直接返回 + 有图片 → 将图片以 Markdown 格式嵌入文本 + """ + text = result.get("text", "") + images = result.get("images", []) + + if not images: + return text or "(OpenClaw 返回了空内容)" + + 
parts = [] + if text: + parts.append(text) + for i, img in enumerate(images): + parts.append(f"![OpenClaw 生成的图片 {i+1}]({img['url']})") + + return "\n\n".join(parts) + + + #=============openclaw执行函数结束================ def _parse_openapi_schema(self) -> Dict[str, Any]: """解析OpenAPI schema""" operations = {} diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index fb4955b3..34037b12 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -165,7 +165,18 @@ class AppChatService: multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") - + #============为 OpenClaw 工具注入会话session====== + # 为 OpenClaw 工具注入运行时上下文 + for t in tools: + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, '_is_openclaw'): + if t.tool_instance._is_openclaw: + t.tool_instance._user_id = user_id or "anonymous" + t.tool_instance._conversation_id = ( + str(conversation_id) if conversation_id else None) + # 注入用户上传的文件 + if processed_files: + t.tool_instance._uploaded_files = processed_files + #============为 OpenClaw 工具注入会话session====== # 调用 Agent(支持多模态) result = await agent.chat( message=message, @@ -413,6 +424,17 @@ class AppChatService: processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") + #============为 OpenClaw 工具注入运行时上下文====== + for t in tools: + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, '_is_openclaw'): + if t.tool_instance._is_openclaw: + t.tool_instance._user_id = user_id or "anonymous" + t.tool_instance._conversation_id = ( + str(conversation_id) if conversation_id else None) + if processed_files: + t.tool_instance._uploaded_files = processed_files + #============为 OpenClaw 工具注入运行时上下文结束====== + # 流式调用 Agent(支持多模态),同时并行启动 TTS full_content = "" full_reasoning = "" diff --git a/api/app/services/draft_run_service.py 
b/api/app/services/draft_run_service.py index 978dfdab..62d7ea71 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -640,7 +640,18 @@ class AgentRunService: multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") - + #================= 为 OpenClaw 工具注入运行时上下文========== + for t in tools: + logger.info(f"检查工具: {type(t).__name__}, has_tool_instance={hasattr(t, 'tool_instance')}, is_openclaw={getattr(getattr(t, 'tool_instance', None), '_is_openclaw', 'N/A')}") + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, '_is_openclaw'): + if t.tool_instance._is_openclaw: + t.tool_instance._user_id = user_id or "anonymous" + t.tool_instance._conversation_id = ( + str(conversation_id) if conversation_id else None) + if processed_files: + t.tool_instance._uploaded_files = processed_files + logger.info(f"已注入 _uploaded_files, 数量: {len(processed_files)}") + #================= 为 OpenClaw 工具注入运行时上下文结束========== # 7. 知识库检索 context = None @@ -890,7 +901,18 @@ class AgentRunService: multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") - + #============为 OpenClaw 工具注入会话session====== + # 为 OpenClaw 工具注入运行时上下文 + for t in tools: + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, '_is_openclaw'): + if t.tool_instance._is_openclaw: + t.tool_instance._user_id = user_id or "anonymous" + t.tool_instance._conversation_id = ( + str(conversation_id) if conversation_id else None) + # 注入用户上传的文件 + if processed_files: + t.tool_instance._uploaded_files = processed_files + #============为 OpenClaw 工具注入会话session====== # 7. 
知识库检索 context = None diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 089f0ec5..0f88a65e 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -330,6 +330,20 @@ class ToolService: if config.tool_type == ToolType.MCP.value: return await self._test_mcp_connection(config) elif config.tool_type == ToolType.CUSTOM.value: + # ========== 测试工具连接 OpenClaw 特判 ========== + custom_config = self.custom_repo.find_by_tool_id(self.db, config.id) + if custom_config and custom_config.schema_content: + schema = custom_config.schema_content + if isinstance(schema, str): + try: + schema = json.loads(schema) + except json.JSONDecodeError: + schema = {} + #请求头中包含OpenClaw字段 + if isinstance(schema, dict) and schema.get("info", {}).get("x-openclaw"): + return await self._test_openclaw_connection(custom_config, schema) + # ========== OpenClaw 特判结束 ========== + #正常自定义工具逻辑 return await self._test_custom_connection(config) elif config.tool_type == ToolType.BUILTIN.value: return await self._test_builtin_connection(config) @@ -339,6 +353,45 @@ class ToolService: except Exception as e: return {"success": False, "message": f"测试失败: {str(e)}"} + #=============测试openclaw连接 特判=============== + async def _test_openclaw_connection( + self, custom_config: CustomToolConfig, schema: dict + ) -> Dict[str, Any]: + """测试 OpenClaw 连接""" + try: + info = schema.get("info", {}) + servers = schema.get("servers", []) + base_url = servers[0]["url"] if servers else "" + token = (custom_config.auth_config or {}).get("token", "") + agent_id = info.get("x-openclaw-agent-id", "main") + model = info.get("x-openclaw-default-model", "openclaw") + + url = f"{base_url.rstrip('/')}/v1/responses" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + "x-openclaw-agent-id": agent_id + } + body = { + "model": model, + "user": "connection-test", + "input": "hi", + "stream": False + } + + timeout_config = 
aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout_config) as session: + async with session.post(url, json=body, headers=headers) as resp: + if resp.status < 400: + return {"success": True, "message": "OpenClaw 连接成功"} + error_text = await resp.text() + return { + "success": False, + "message": f"OpenClaw HTTP {resp.status}: {error_text[:200]}" + } + except Exception as e: + return {"success": False, "message": f"OpenClaw 连接失败: {str(e)}"} + #=============测试openclaw连接结束=========== def ensure_builtin_tools_initialized(self, tenant_id: uuid.UUID): """确保内置工具已初始化""" existing = self.tool_repo.exists_builtin_for_tenant(self.db, tenant_id) @@ -1139,6 +1192,27 @@ class ToolService: custom_config = self.db.query(CustomToolConfig).filter( CustomToolConfig.id == tool_config.id ).first() + # ========== 更新工具 OpenClaw 特判 ========== + if custom_config and custom_config.schema_content: + schema = custom_config.schema_content + if isinstance(schema, str): + try: + schema = json.loads(schema) + except json.JSONDecodeError: + schema = {} + info = schema.get("info", {}) if isinstance(schema, dict) else {} + if info.get("x-openclaw"): + servers = schema.get("servers", []) + has_url = bool(servers and servers[0].get("url")) + has_agent_id = bool(info.get("x-openclaw-agent-id")) + has_token = bool(custom_config.auth_config + and custom_config.auth_config.get("api_key")) + if has_url and has_agent_id and has_token: + tool_config.status = ToolStatus.AVAILABLE.value + else: + tool_config.status = ToolStatus.UNCONFIGURED.value + return + # ========== OpenClaw 特判结束 ========== if custom_config and tool_config.name and (custom_config.schema_content or custom_config.schema_url): tool_config.status = ToolStatus.AVAILABLE.value From 562ca6c1f15fb3abb9868b3b20ed05f5b228f25b Mon Sep 17 00:00:00 2001 From: miao <1468212639@qq.com> Date: Tue, 7 Apr 2026 13:58:38 +0800 Subject: [PATCH 11/85] fix(tools): fix OpenClaw connection test and multimodal format compatibility - Use 
safe .get() for server URL to avoid KeyError - Support both api_key and token in connection test auth - Add OpenAI/Volcano image format (image_url) support - Add aiohttp import in _test_openclaw_connection --- api/app/core/tools/custom/base.py | 10 ++++++++-- api/app/services/draft_run_service.py | 1 - api/app/services/tool_service.py | 8 ++++++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/api/app/core/tools/custom/base.py b/api/app/core/tools/custom/base.py index 3f3daad7..c7858a7b 100644 --- a/api/app/core/tools/custom/base.py +++ b/api/app/core/tools/custom/base.py @@ -220,7 +220,9 @@ class CustomTool(BaseTool): image_url = None if self._uploaded_files: for f in self._uploaded_files: - if f.get("type") == "image": + f_type = f.get("type", "") + if f_type == "image": + # Bedrock/Anthropic 格式:{"type": "image", "source": {"type": "base64", ...}} source = f.get("source", {}) if source.get("type") == "base64": media_type = source.get("media_type", "image/jpeg") @@ -232,7 +234,11 @@ class CustomTool(BaseTool): elif f.get("url"): # 其他格式:{"type": "image", "url": "https://..."} image_url = f.get("url") - break # 只取第一张图片 + break + elif f_type == "image_url": + # OpenAI/Volcano 格式:{"type": "image_url", "image_url": {"url": "..."}} + image_url = f.get("image_url", {}).get("url", "") + break # 如果 image_url 是服务器中转 URL,直接下载图片转 base64 # 避免 OSS 签名 URL 在重定向解析过程中被破坏 diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 62d7ea71..fa307ec5 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -642,7 +642,6 @@ class AgentRunService: logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") #================= 为 OpenClaw 工具注入运行时上下文========== for t in tools: - logger.info(f"检查工具: {type(t).__name__}, has_tool_instance={hasattr(t, 'tool_instance')}, is_openclaw={getattr(getattr(t, 'tool_instance', None), '_is_openclaw', 'N/A')}") if hasattr(t, 'tool_instance') and 
hasattr(t.tool_instance, '_is_openclaw'): if t.tool_instance._is_openclaw: t.tool_instance._user_id = user_id or "anonymous" diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 0f88a65e..9c9faf69 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -358,11 +358,15 @@ class ToolService: self, custom_config: CustomToolConfig, schema: dict ) -> Dict[str, Any]: """测试 OpenClaw 连接""" + import aiohttp try: info = schema.get("info", {}) servers = schema.get("servers", []) - base_url = servers[0]["url"] if servers else "" - token = (custom_config.auth_config or {}).get("token", "") + base_url = servers[0].get("url", "") if servers else "" + if not base_url: + return {"success": False, "message": "OpenClaw 未配置 server URL"} + auth = custom_config.auth_config or {} + token = auth.get("api_key") or auth.get("token") or "" agent_id = info.get("x-openclaw-agent-id", "main") model = info.get("x-openclaw-default-model", "openclaw") From 55b2e05ba8fb2fcd623bffed37fe0b9962d2f3da Mon Sep 17 00:00:00 2001 From: miao <1468212639@qq.com> Date: Thu, 9 Apr 2026 18:13:21 +0800 Subject: [PATCH 12/85] feat(tools): refactor migrate OpenClaw from custom tool to builtin tool Create OpenClawTool class inheriting BuiltinTool with dedicated config Remove all x-openclaw special handling from CustomTool (~270 lines) Add multi-operation support (print_task, device_query, image_understand, general) Change ensure_builtin_tools_initialized to incremental mode for auto-provisioning Fix OperationTool and LangchainAdapter to support OpenClaw operation routing --- api/app/core/tools/builtin/openclaw_tool.py | 298 +++++++++++++++ api/app/core/tools/builtin/operation_tool.py | 60 +++ .../tools/configs/builtin/openclaw_tool.json | 15 + api/app/core/tools/configs/builtin_tools.json | 13 + api/app/core/tools/custom/base.py | 346 +----------------- api/app/core/tools/langchain_adapter.py | 2 +- api/app/repositories/tool_repository.py | 11 + 
api/app/services/app_chat_service.py | 34 +- api/app/services/draft_run_service.py | 35 +- api/app/services/tool_service.py | 164 ++++----- 10 files changed, 503 insertions(+), 475 deletions(-) create mode 100644 api/app/core/tools/builtin/openclaw_tool.py create mode 100644 api/app/core/tools/configs/builtin/openclaw_tool.json diff --git a/api/app/core/tools/builtin/openclaw_tool.py b/api/app/core/tools/builtin/openclaw_tool.py new file mode 100644 index 00000000..161769f1 --- /dev/null +++ b/api/app/core/tools/builtin/openclaw_tool.py @@ -0,0 +1,298 @@ +"""OpenClaw 远程 Agent 内置工具""" +import time +import base64 +from io import BytesIO +from typing import List, Dict, Any, Optional +import aiohttp + +from app.core.tools.builtin.base import BuiltinTool +from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class OpenClawTool(BuiltinTool): + """OpenClaw 远程 Agent 工具 — 支持文本和图片多模态输入""" + + def __init__(self, tool_id: str, config: Dict[str, Any]): + super().__init__(tool_id, config) + params = self.parameters_config + + # 用户配置项(前端表单填写) + self._server_url = params.get("server_url", "") + self._api_key = params.get("api_key", "") + self._agent_id = params.get("agent_id", "main") + + # 内部默认值 + self._model = "openclaw" + self._session_strategy = "by_user" + self._timeout = 120 + + # 运行时上下文(通过 set_runtime_context 注入) + self._user_id = "anonymous" + self._conversation_id = None + self._uploaded_files = [] + + @property + def name(self) -> str: + return "openclaw_tool" + + @property + def description(self) -> str: + return ( + "OpenClaw 远程 Agent:将任务委托给远程 OpenClaw Agent。" + "具备 3D 模型生成与打印控制、设备管理、文件处理、浏览器自动化、" + "Shell 命令执行、网络搜索等能力。支持文本和图片多模态交互。" + ) + + @property + def parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="operation", + type=ParameterType.STRING, + description="任务类型", + required=True, + enum= ["print_task", "device_query", 
"image_understand", "general"] + ), + ToolParameter( + name="message", + type=ParameterType.STRING, + description="发送给 OpenClaw Agent 的文本请求内容", + required=True + ), + ToolParameter( + name="image_url", + type=ParameterType.STRING, + description="可选,附带的图片 URL 或 base64 data URI(OpenClaw 支持图片输入)", + required=False + ) + ] + + # ---------- 运行时上下文注入 ---------- + def set_runtime_context( + self, + user_id: str = "anonymous", + conversation_id: Optional[str] = None, + uploaded_files: Optional[list] = None + ): + """注入运行时上下文(由 chat service 调用)""" + self._user_id = user_id + self._conversation_id = conversation_id + self._uploaded_files = uploaded_files or [] + + # ---------- 连接测试 ---------- + async def test_connection(self) -> Dict[str, Any]: + """测试 OpenClaw Gateway 连接""" + if not self._server_url: + return {"success": False, "message": "未配置 server_url"} + if not self._api_key: + return {"success": False, "message": "未配置 api_key"} + + url = f"{self._server_url.rstrip('/')}/v1/responses" + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + "x-openclaw-agent-id": self._agent_id + } + body = { + "model": self._model, + "user": "connection-test", + "input": "hi", + "stream": False + } + try: + timeout_cfg = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout_cfg) as session: + async with session.post(url, json=body, headers=headers) as resp: + if resp.status < 400: + return {"success": True, "message": "OpenClaw 连接成功"} + error_text = await resp.text() + return { + "success": False, + "message": f"OpenClaw HTTP {resp.status}: {error_text[:200]}" + } + except Exception as e: + return {"success": False, "message": f"OpenClaw 连接失败: {str(e)}"} + + # ---------- 执行 ---------- + async def execute(self, **kwargs) -> ToolResult: + """执行 OpenClaw 调用""" + start_time = time.time() + try: + message = kwargs.get("message", "") + operation = kwargs.get("operation", "unknown") + if not message: + return 
ToolResult.error_result( + error="message 参数不能为空", + error_code="OPENCLAW_INVALID_INPUT", + execution_time=time.time() - start_time + ) + + # 提取图片:优先从用户上传文件中获取,LLM 传的 image_url 作为兜底 + image_url = self._extract_image_from_uploads() + if not image_url: + image_url = kwargs.get("image_url") + if image_url and not image_url.startswith("data:"): + image_url = await self._download_and_encode_image(image_url) + + # 构建请求 + url = f"{self._server_url.rstrip('/')}/v1/responses" + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + "x-openclaw-agent-id": self._agent_id + } + user_field = ( + f"conv-{self._conversation_id}" + if self._session_strategy == "by_conversation" and self._conversation_id + else f"user-{self._user_id}" + ) + input_field = self._build_input(message, image_url) + body = { + "model": self._model, + "user": user_field, + "input": input_field, + "stream": False + } + + timeout_cfg = aiohttp.ClientTimeout(total=self._timeout) + # 打印请求日志(截断 base64 避免日志过大) + log_body = {**body} + if isinstance(log_body.get("input"), list): + log_body["input"] = "[multimodal input, truncated]" + elif isinstance(log_body.get("input"), str) and len(log_body["input"]) > 500: + log_body["input"] = log_body["input"][:500] + "..." 
+ logger.info( + f"OpenClaw 请求: url={url}, agent_id={self._agent_id}, " + f"has_image={bool(image_url)}, body={log_body}" + ) + async with aiohttp.ClientSession(timeout=timeout_cfg) as session: + async with session.post(url, json=body, headers=headers) as resp: + execution_time = time.time() - start_time + if resp.status >= 400: + error_text = await resp.text() + return ToolResult.error_result( + error=f"OpenClaw HTTP {resp.status}: {error_text[:500]}", + error_code="OPENCLAW_HTTP_ERROR", + execution_time=execution_time + ) + data = await resp.json() + text = self._extract_response(data) + display_text = self._format_result(text) + return ToolResult.success_result( + data=display_text, + execution_time=execution_time + ) + + except aiohttp.ClientError as e: + return ToolResult.error_result( + error=f"OpenClaw 网络连接失败: {str(e)}", + error_code="OPENCLAW_NETWORK_ERROR", + execution_time=time.time() - start_time + ) + except Exception as e: + return ToolResult.error_result( + error=f"OpenClaw 调用失败: {str(e)}", + error_code="OPENCLAW_EXECUTION_ERROR", + execution_time=time.time() - start_time + ) + + # ---------- 私有方法 ---------- + def _extract_image_from_uploads(self) -> Optional[str]: + """从用户上传文件中提取图片 URL""" + for f in self._uploaded_files: + f_type = f.get("type", "") + if f_type == "image": + source = f.get("source", {}) + if source.get("type") == "base64": + media_type = source.get("media_type", "image/jpeg") + data = source.get("data", "") + return f"data:{media_type};base64,{data}" + elif f.get("image"): + return f.get("image") + elif f.get("url"): + return f.get("url") + elif f_type == "image_url": + return f.get("image_url", {}).get("url", "") + return None + + async def _download_and_encode_image(self, image_url: str) -> str: + """下载图片并转为 base64 data URI""" + try: + from PIL import Image + MAX_RAW_SIZE = 4 * 1024 * 1024 + + async with aiohttp.ClientSession() as session: + async with session.get( + image_url, allow_redirects=True, + 
timeout=aiohttp.ClientTimeout(total=30) + ) as resp: + if resp.status != 200: + return image_url + content_type = resp.headers.get("Content-Type", "image/jpeg") + if not content_type.startswith("image/"): + return image_url + img_bytes = await resp.read() + + if len(img_bytes) > MAX_RAW_SIZE: + img = Image.open(BytesIO(img_bytes)) + if img.mode in ("RGBA", "P", "LA"): + img = img.convert("RGB") + if max(img.size) > 2048: + img.thumbnail((2048, 2048), Image.LANCZOS) + buf = BytesIO() + img.save(buf, format="JPEG", quality=75, optimize=True) + img_bytes = buf.getvalue() + content_type = "image/jpeg" + + b64 = base64.b64encode(img_bytes).decode("utf-8") + return f"data:{content_type};base64,{b64}" + except Exception as e: + logger.warning(f"OpenClaw 下载图片失败,使用原始 URL: {e}") + return image_url + + def _build_input(self, message: str, image_url: Optional[str] = None): + """构造请求 input 字段:有图片则构造多模态结构,否则纯文本""" + if not image_url: + return message + + content_parts = [{"type": "input_text", "text": message}] + if image_url.startswith("data:"): + try: + header, data = image_url.split(",", 1) + media_type = header.split(":")[1].split(";")[0] + content_parts.append({ + "type": "input_image", + "source": {"type": "base64", "media_type": media_type, "data": data} + }) + except (ValueError, IndexError): + return message + else: + content_parts.append({ + "type": "input_image", + "source": {"type": "url", "url": image_url} + }) + + return [{"type": "message", "role": "user", "content": content_parts}] + + def _extract_response(self, response_data: Dict[str, Any]) -> str: + """从 OpenClaw 响应中提取文本内容 + + OpenClaw /v1/responses 只返回 output_text 类型的内容。 + 图片信息(如有)由 OpenClaw Skill 以 Markdown 链接形式嵌入文本中返回。 + """ + output = response_data.get("output", []) + texts = [] + for item in output: + if item.get("type") == "message": + for content in item.get("content", []): + if content.get("type") == "output_text" and content.get("text"): + texts.append(content["text"]) + return "\n".join(texts) if 
texts else str(response_data) + + @staticmethod + def _format_result(text: str) -> str: + """格式化结果为 LLM 可读字符串""" + return text or "(OpenClaw 返回了空内容)" diff --git a/api/app/core/tools/builtin/operation_tool.py b/api/app/core/tools/builtin/operation_tool.py index 126541a8..495551af 100644 --- a/api/app/core/tools/builtin/operation_tool.py +++ b/api/app/core/tools/builtin/operation_tool.py @@ -32,6 +32,8 @@ class OperationTool(BaseTool): return self._get_datetime_params() elif self.base_tool.name == 'json_tool': return self._get_json_params() + elif self.base_tool.name == 'openclaw_tool': + return self._get_openclaw_params() else: # 默认返回除operation外的所有参数 return [p for p in self.base_tool.parameters if p.name != "operation"] @@ -209,6 +211,64 @@ class OperationTool(BaseTool): else: return base_params + def _get_openclaw_params(self) -> List[ToolParameter]: + """获取 openclaw_tool 特定操作的参数""" + if self.operation == "print_task": + return [ + ToolParameter( + name="message", + type=ParameterType.STRING, + description="发送给 OpenClaw 的打印任务描述,将用户的原始消息原封不动地传递给 OpenClaw,禁止改写、补充或润色用户的原文", + required=True + ), + ToolParameter( + name="image_url", + type=ParameterType.STRING, + description="可选,附带的设计图片或参考图,OpenClaw 可据此生成 3D 模型", + required=False + ) + ] + elif self.operation == "device_query": + return [ + ToolParameter( + name="message", + type=ParameterType.STRING, + description="发送给 OpenClaw 的设备查询指令", + required=True + ) + ] + elif self.operation == "image_understand": + return [ + ToolParameter( + name="message", + type=ParameterType.STRING, + description="发送给 OpenClaw 的图片理解任务,应描述需要对图片做什么(如描述内容、提取文字、分析信息)", + required=True + ), + ToolParameter( + name="image_url", + type=ParameterType.STRING, + description="必须提供,要分析的图片 URL 或 base64 data URI", + required=True + ) + ] + else: + # general 及其他 + return [ + ToolParameter( + name="message", + type=ParameterType.STRING, + description="发送给 OpenClaw Agent 的任务描述,应包含完整的任务需求", + required=True + ), + ToolParameter( + name="image_url", + 
type=ParameterType.STRING, + description="可选,附带的图片 URL 或 base64 data URI", + required=False + ) + ] + async def execute(self, **kwargs) -> ToolResult: """执行特定操作""" # 添加operation参数 diff --git a/api/app/core/tools/configs/builtin/openclaw_tool.json b/api/app/core/tools/configs/builtin/openclaw_tool.json new file mode 100644 index 00000000..7c1f9629 --- /dev/null +++ b/api/app/core/tools/configs/builtin/openclaw_tool.json @@ -0,0 +1,15 @@ +{ + "name": "openclaw_tool", + "description": "调用OpenClaw Agent远程服务", + "tool_class": "OpenClawTool", + "category": "agent", + "requires_config": true, + "version": "1.0.0", + "enabled": true, + "parameters": { + "server_url": "", + "api_key": "", + "agent_id": "main" + }, + "tags": ["agent", "openclaw", "multimodal", "3d-printing", "builtin"] +} diff --git a/api/app/core/tools/configs/builtin_tools.json b/api/app/core/tools/configs/builtin_tools.json index 79206a5e..882a970a 100644 --- a/api/app/core/tools/configs/builtin_tools.json +++ b/api/app/core/tools/configs/builtin_tools.json @@ -30,5 +30,18 @@ "parameters": { "api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true} } + }, + "openclaw": { + "name": "OpenClaw远程Agent", + "description": "OpenClaw Agent远程服务", + "tool_class": "OpenClawTool", + "category": "agent", + "requires_config": true, + "version": "1.0.0", + "enabled": true, + "parameters": { + "server_url": {"type": "string", "description": "OpenClaw Gateway 地址", "required": true}, + "api_key": {"type": "string", "description": "OpenClaw API Key", "sensitive": true, "required": true} + } } } \ No newline at end of file diff --git a/api/app/core/tools/custom/base.py b/api/app/core/tools/custom/base.py index c7858a7b..f6e191ed 100644 --- a/api/app/core/tools/custom/base.py +++ b/api/app/core/tools/custom/base.py @@ -31,7 +31,7 @@ class CustomTool(BaseTool): self.base_url = config.get("base_url", "") self.timeout = config.get("timeout", 30) - 
#===========OpenClaw特殊判断(取到OpenClaw特殊配置)========== + # 解析 schema schema = self.schema_content if isinstance(schema, str): try: @@ -39,51 +39,7 @@ class CustomTool(BaseTool): self.schema_content = schema except json.JSONDecodeError: schema = {} - - info = schema.get("info", {}) if isinstance(schema, dict) else {} - self._is_openclaw = info.get("x-openclaw", False) - - if self._is_openclaw: - # 从扩展字段读取 OpenClaw 配置 - self._openclaw_agent_id = info.get("x-openclaw-agent-id", "main") - self._openclaw_model = info.get("x-openclaw-default-model", "openclaw") - self._openclaw_session_strategy = info.get( - "x-openclaw-session-strategy", "by_user") - self._openclaw_timeout = info.get("x-openclaw-timeout", 60) - self._openclaw_input_mode = info.get("x-openclaw-input-mode", "text") - self._openclaw_output_mode = info.get("x-openclaw-output-mode", "text") - - # 从 servers 读取 base_url - servers = schema.get("servers", []) - if servers: - self.base_url = servers[0].get("url", "") - - # 从 auth_config 读取 token(兼容 api_key 和 bearer_token 两种认证方式) - self._openclaw_token = ( - self.auth_config.get("api_key") # api_key 认证方式 - or self.auth_config.get("token") # bearer_token 认证方式 - or "" - ) - - # 覆盖 timeout - self.timeout = self._openclaw_timeout - - # 运行时上下文(后续注入) - self._user_id = "anonymous" - self._conversation_id = None - self._uploaded_files = [] # 新增:用户上传的文件 - - # 跳过 Schema 解析 - self._parsed_operations = {} - - logger.info( - f"检测到 OpenClaw 工具: agent_id={self._openclaw_agent_id}, " - f"base_url={self.base_url}, " - f"input_mode={self._openclaw_input_mode}, " - f"output_mode={self._openclaw_output_mode}") - else: - # 解析schema - self._parsed_operations = self._parse_openapi_schema() + self._parsed_operations = self._parse_openapi_schema() @property def name(self) -> str: @@ -109,31 +65,6 @@ class CustomTool(BaseTool): @property def parameters(self) -> List[ToolParameter]: """工具参数定义""" - # ========== OpenClaw 特判 根据输入模式解析是否需要image_url ========== - if self._is_openclaw: - params = [ - 
ToolParameter( - name="message", - type=ParameterType.STRING, - description="发送给 OpenClaw Agent 的文本请求内容", - required=True - ) - ] - # 多模态输入模式下,增加 image_url 参数 - if self._openclaw_input_mode == "multimodal": - params.append(ToolParameter( - name="image_url", - type=ParameterType.STRING, - description=( - "可选,附带的图片URL或base64 data URI" - "(如 data:image/png;base64,...)。" - "传入后 Agent 可以理解图片内容。" - ), - required=False - )) - return params - # ========== 特判结束 ========== - params = [] # 添加操作选择参数 @@ -166,10 +97,6 @@ class CustomTool(BaseTool): async def execute(self, **kwargs) -> ToolResult: """执行自定义工具""" - # ========== OpenClaw 特判 ========== - if self._is_openclaw: - return await self._execute_openclaw(**kwargs) - # ========== 特判结束 ========== start_time = time.time() try: @@ -210,275 +137,6 @@ class CustomTool(BaseTool): execution_time=execution_time ) - #=============openclaw执行函数开始=============== - async def _execute_openclaw(self, **kwargs) -> ToolResult: - """OpenClaw 专属执行逻辑(支持多模态输入)""" - start_time = time.time() - try: - message = kwargs.get("message", "") - # 从用户实际上传的文件中提取图片 URL - image_url = None - if self._uploaded_files: - for f in self._uploaded_files: - f_type = f.get("type", "") - if f_type == "image": - # Bedrock/Anthropic 格式:{"type": "image", "source": {"type": "base64", ...}} - source = f.get("source", {}) - if source.get("type") == "base64": - media_type = source.get("media_type", "image/jpeg") - data = source.get("data", "") - image_url = f"data:{media_type};base64,{data}" - elif f.get("image"): - # DashScope 格式:{"type": "image", "image": "url"} - image_url = f.get("image") - elif f.get("url"): - # 其他格式:{"type": "image", "url": "https://..."} - image_url = f.get("url") - break - elif f_type == "image_url": - # OpenAI/Volcano 格式:{"type": "image_url", "image_url": {"url": "..."}} - image_url = f.get("image_url", {}).get("url", "") - break - - # 如果 image_url 是服务器中转 URL,直接下载图片转 base64 - # 避免 OSS 签名 URL 在重定向解析过程中被破坏 - if image_url and not 
image_url.startswith("data:"): - try: - import base64 - from io import BytesIO - from PIL import Image - - MAX_RAW_SIZE = 4 * 1024 * 1024 # 超过 4MB 则压缩 - - async with aiohttp.ClientSession() as _session: - async with _session.get(image_url, allow_redirects=True, timeout=aiohttp.ClientTimeout(total=30)) as _resp: - if _resp.status == 200: - content_type = _resp.headers.get("Content-Type", "image/jpeg") - if content_type.startswith("image/"): - img_bytes = await _resp.read() - original_size = len(img_bytes) - logger.info(f"OpenClaw 下载图片: size={original_size} bytes, type={content_type}") - - if original_size > MAX_RAW_SIZE: - img = Image.open(BytesIO(img_bytes)) - if img.mode in ("RGBA", "P", "LA"): - img = img.convert("RGB") - max_side = 2048 - if max(img.size) > max_side: - img.thumbnail((max_side, max_side), Image.LANCZOS) - buf = BytesIO() - img.save(buf, format="JPEG", quality=75, optimize=True) - img_bytes = buf.getvalue() - content_type = "image/jpeg" - logger.info(f"OpenClaw 图片已压缩: {original_size} -> {len(img_bytes)} bytes") - - b64_data = base64.b64encode(img_bytes).decode("utf-8") - image_url = f"data:{content_type};base64,{b64_data}" - logger.info(f"OpenClaw 图片已转为 base64, size={len(img_bytes)} bytes") - else: - logger.warning(f"OpenClaw 图片 URL 返回非图片类型: {content_type}") - else: - logger.warning(f"OpenClaw 下载图片失败: HTTP {_resp.status}") - except Exception as e: - logger.warning(f"OpenClaw 下载图片失败,使用原始 URL: {e}") - - - if not message: - return ToolResult.error_result( - error="message 参数不能为空", - error_code="OPENCLAW_INVALID_INPUT", - execution_time=time.time() - start_time) - - url = f"{self.base_url.rstrip('/')}/v1/responses" - #请求头 - headers = { - "Authorization": f"Bearer {self._openclaw_token}", - "Content-Type": "application/json", - "x-openclaw-agent-id": self._openclaw_agent_id - } - - # session 路由 - if (self._openclaw_session_strategy == "by_conversation" - and self._conversation_id): - user_field = f"conv-{self._conversation_id}" - else: - user_field = 
f"user-{self._user_id}" - - # 根据 input_mode 和是否有图片构造 input - input_field = self._build_openclaw_input(message, image_url) - #请求体 - body = { - "model": self._openclaw_model, - "user": user_field, - "input": input_field, - "stream": False - } - - logger.info(f"OpenClaw 请求体: {json.dumps(body, ensure_ascii=False)[:1000]}") - - timeout_config = aiohttp.ClientTimeout(total=self.timeout) - #请求 - async with aiohttp.ClientSession(timeout=timeout_config) as session: - async with session.post(url, json=body, headers=headers) as resp: - execution_time = time.time() - start_time - - if resp.status >= 400: - error_text = await resp.text() - _img_preview2 = (image_url[:100] + "...") if image_url and len(image_url) > 100 else image_url - logger.error( - f"OpenClaw 调用失败: HTTP {resp.status}, " - f"url={url}, agent_id={self._openclaw_agent_id}, " - f"has_image={bool(image_url)}, image_url={_img_preview2}, " - f"input_type={'multimodal' if isinstance(input_field, list) else 'text'}, " - f"error_response={error_text[:1000]}" - ) - return ToolResult.error_result( - error=f"OpenClaw HTTP {resp.status}: {error_text[:500]}", - error_code="OPENCLAW_HTTP_ERROR", - execution_time=execution_time) - - data = await resp.json() - - # 根据 output_mode 解析响应 - result = self._extract_openclaw_response( - data, self._openclaw_output_mode) - display_text = self._format_openclaw_result(result) - - logger.info( - "OpenClaw 调用成功", - extra={ - "tool_id": self.tool_id, - "agent_id": self._openclaw_agent_id, - "has_images": len(result["images"]) > 0, - "execution_time": execution_time - }) - return ToolResult.success_result( - data=display_text, execution_time=execution_time) - - except aiohttp.ClientError as e: - return ToolResult.error_result( - error=f"OpenClaw 网络连接失败: {str(e)}", - error_code="OPENCLAW_NETWORK_ERROR", - execution_time=time.time() - start_time) - except Exception as e: - return ToolResult.error_result( - error=f"OpenClaw 调用失败: {str(e)}", - error_code="OPENCLAW_EXECUTION_ERROR", - 
execution_time=time.time() - start_time) - - def _build_openclaw_input(self, message: str, image_url: str = None): - """根据 input_mode 和是否有图片构造 OpenClaw input 字段 - - 纯文本模式或无图片 → 返回字符串 - 多模态模式且有图片 → 返回结构化 item 数组 - """ - if not image_url or self._openclaw_input_mode != "multimodal": - return message - - # 构造多模态 content 数组 - content_parts = [ - {"type": "input_text", "text": message} - ] - - if image_url.startswith("data:"): - # base64 data URI: data:image/png;base64,iVBORw0KGgo... - try: - header, data = image_url.split(",", 1) - media_type = header.split(":")[1].split(";")[0] - content_parts.append({ - "type": "input_image", - "source": { - "type": "base64", - "media_type": media_type, - "data": data - } - }) - except (ValueError, IndexError): - logger.warning("无法解析 base64 data URI,回退为纯文本输入") - return message - else: - # URL 引用 - content_parts.append({ - "type": "input_image", - "source": { - "type": "url", - "url": image_url - } - }) - - return [{ - "type": "message", - "role": "user", - "content": content_parts - }] - - @staticmethod - def _extract_openclaw_response(response_data: Dict[str, Any], - output_mode: str = "text") -> Dict[str, Any]: - """从 OpenClaw 响应中提取文本和图片 - - 响应格式: - {"output": [{"type": "message", "content": [ - {"type": "output_text", "text": "..."}, - {"type": "output_image", "image_url": "..."} - ]}]} - - 返回: - {"text": "文本内容", "images": [{"url": "...", "media_type": "image/png"}]} - """ - output = response_data.get("output", []) - texts = [] - images = [] - - for item in output: - if item.get("type") == "message": - for content in item.get("content", []): - content_type = content.get("type") - - if content_type == "output_text": - text = content.get("text", "") - if text: - texts.append(text) - - elif content_type == "output_image" and output_mode == "multimodal": - image_url = content.get("image_url", "") - if image_url: - images.append({ - "url": image_url, - "media_type": content.get("media_type", "image/png") - }) - - text_result = 
"\n".join(texts) if texts else "" - - # text 模式下只返回文本(向后兼容) - if output_mode == "text": - return {"text": text_result or str(response_data), "images": []} - - return {"text": text_result, "images": images} - - @staticmethod - def _format_openclaw_result(result: Dict[str, Any]) -> str: - """将解析结果格式化为返回给 LLM 的字符串 - - 纯文本 → 直接返回 - 有图片 → 将图片以 Markdown 格式嵌入文本 - """ - text = result.get("text", "") - images = result.get("images", []) - - if not images: - return text or "(OpenClaw 返回了空内容)" - - parts = [] - if text: - parts.append(text) - for i, img in enumerate(images): - parts.append(f"![OpenClaw 生成的图片 {i+1}]({img['url']})") - - return "\n\n".join(parts) - - - #=============openclaw执行函数结束================ def _parse_openapi_schema(self) -> Dict[str, Any]: """解析OpenAPI schema""" operations = {} diff --git a/api/app/core/tools/langchain_adapter.py b/api/app/core/tools/langchain_adapter.py index 51415732..859b6312 100644 --- a/api/app/core/tools/langchain_adapter.py +++ b/api/app/core/tools/langchain_adapter.py @@ -131,7 +131,7 @@ class LangchainAdapter: def _tool_supports_operations(tool: BaseTool) -> bool: """检查工具是否支持多操作""" # 内置工具中支持操作的工具 - builtin_operation_tools = ['datetime_tool', 'json_tool'] + builtin_operation_tools = ['datetime_tool', 'json_tool', 'openclaw_tool'] # 检查内置工具 if tool.tool_type.value == "builtin" and tool.name in builtin_operation_tools: diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py index 1a9b0b87..1348c4e8 100644 --- a/api/app/repositories/tool_repository.py +++ b/api/app/repositories/tool_repository.py @@ -161,6 +161,17 @@ class BuiltinToolRepository: BuiltinToolConfig.id == tool_id ).first() + @staticmethod + def get_existing_tool_classes(db: Session, tenant_id: uuid.UUID) -> set: + """获取该租户已有的内置工具 tool_class 集合""" + rows = db.query(BuiltinToolConfig.tool_class).join( + ToolConfig, BuiltinToolConfig.id == ToolConfig.id + ).filter( + ToolConfig.tenant_id == tenant_id, + ToolConfig.tool_type == 
ToolType.BUILTIN.value + ).all() + return {row[0] for row in rows} + class CustomToolRepository: """自定义工具仓储类""" diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 34037b12..ec0c4b79 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -165,18 +165,14 @@ class AppChatService: multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") - #============为 OpenClaw 工具注入会话session====== - # 为 OpenClaw 工具注入运行时上下文 + # 为需要运行时上下文的工具注入上下文 for t in tools: - if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, '_is_openclaw'): - if t.tool_instance._is_openclaw: - t.tool_instance._user_id = user_id or "anonymous" - t.tool_instance._conversation_id = ( - str(conversation_id) if conversation_id else None) - # 注入用户上传的文件 - if processed_files: - t.tool_instance._uploaded_files = processed_files - #============为 OpenClaw 工具注入会话session====== + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): + t.tool_instance.set_runtime_context( + user_id=user_id or "anonymous", + conversation_id=str(conversation_id) if conversation_id else None, + uploaded_files=processed_files or [] + ) # 调用 Agent(支持多模态) result = await agent.chat( message=message, @@ -424,16 +420,14 @@ class AppChatService: processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") - #============为 OpenClaw 工具注入运行时上下文====== + # 为需要运行时上下文的工具注入上下文 for t in tools: - if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, '_is_openclaw'): - if t.tool_instance._is_openclaw: - t.tool_instance._user_id = user_id or "anonymous" - t.tool_instance._conversation_id = ( - str(conversation_id) if conversation_id else None) - if processed_files: - t.tool_instance._uploaded_files = processed_files - #============为 OpenClaw 工具注入运行时上下文结束====== + if hasattr(t, 
'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): + t.tool_instance.set_runtime_context( + user_id=user_id or "anonymous", + conversation_id=str(conversation_id) if conversation_id else None, + uploaded_files=processed_files or [] + ) # 流式调用 Agent(支持多模态),同时并行启动 TTS full_content = "" diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index fa307ec5..5c10e4f8 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -640,17 +640,14 @@ class AgentRunService: multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") - #================= 为 OpenClaw 工具注入运行时上下文========== + # 为需要运行时上下文的工具注入上下文 for t in tools: - if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, '_is_openclaw'): - if t.tool_instance._is_openclaw: - t.tool_instance._user_id = user_id or "anonymous" - t.tool_instance._conversation_id = ( - str(conversation_id) if conversation_id else None) - if processed_files: - t.tool_instance._uploaded_files = processed_files - logger.info(f"已注入 _uploaded_files, 数量: {len(processed_files)}") - #================= 为 OpenClaw 工具注入运行时上下文结束========== + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): + t.tool_instance.set_runtime_context( + user_id=user_id or "anonymous", + conversation_id=str(conversation_id) if conversation_id else None, + uploaded_files=processed_files or [] + ) # 7. 
知识库检索 context = None @@ -900,18 +897,14 @@ class AgentRunService: multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") - #============为 OpenClaw 工具注入会话session====== - # 为 OpenClaw 工具注入运行时上下文 + # 为需要运行时上下文的工具注入上下文 for t in tools: - if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, '_is_openclaw'): - if t.tool_instance._is_openclaw: - t.tool_instance._user_id = user_id or "anonymous" - t.tool_instance._conversation_id = ( - str(conversation_id) if conversation_id else None) - # 注入用户上传的文件 - if processed_files: - t.tool_instance._uploaded_files = processed_files - #============为 OpenClaw 工具注入会话session====== + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): + t.tool_instance.set_runtime_context( + user_id=user_id or "anonymous", + conversation_id=str(conversation_id) if conversation_id else None, + uploaded_files=processed_files or [] + ) # 7. 
知识库检索 context = None diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 9c9faf69..20961119 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -34,7 +34,8 @@ BUILTIN_TOOLS = { "JsonTool": "app.core.tools.builtin.json_tool", "BaiduSearchTool": "app.core.tools.builtin.baidu_search_tool", "MinerUTool": "app.core.tools.builtin.mineru_tool", - "TextInTool": "app.core.tools.builtin.textin_tool" + "TextInTool": "app.core.tools.builtin.textin_tool", + "OpenClawTool": "app.core.tools.builtin.openclaw_tool", } @@ -330,20 +331,6 @@ class ToolService: if config.tool_type == ToolType.MCP.value: return await self._test_mcp_connection(config) elif config.tool_type == ToolType.CUSTOM.value: - # ========== 测试工具连接 OpenClaw 特判 ========== - custom_config = self.custom_repo.find_by_tool_id(self.db, config.id) - if custom_config and custom_config.schema_content: - schema = custom_config.schema_content - if isinstance(schema, str): - try: - schema = json.loads(schema) - except json.JSONDecodeError: - schema = {} - #请求头中包含OpenClaw字段 - if isinstance(schema, dict) and schema.get("info", {}).get("x-openclaw"): - return await self._test_openclaw_connection(custom_config, schema) - # ========== OpenClaw 特判结束 ========== - #正常自定义工具逻辑 return await self._test_custom_connection(config) elif config.tool_type == ToolType.BUILTIN.value: return await self._test_builtin_connection(config) @@ -353,62 +340,19 @@ class ToolService: except Exception as e: return {"success": False, "message": f"测试失败: {str(e)}"} - #=============测试openclaw连接 特判=============== - async def _test_openclaw_connection( - self, custom_config: CustomToolConfig, schema: dict - ) -> Dict[str, Any]: - """测试 OpenClaw 连接""" - import aiohttp - try: - info = schema.get("info", {}) - servers = schema.get("servers", []) - base_url = servers[0].get("url", "") if servers else "" - if not base_url: - return {"success": False, "message": "OpenClaw 未配置 server URL"} - auth = 
custom_config.auth_config or {} - token = auth.get("api_key") or auth.get("token") or "" - agent_id = info.get("x-openclaw-agent-id", "main") - model = info.get("x-openclaw-default-model", "openclaw") - - url = f"{base_url.rstrip('/')}/v1/responses" - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - "x-openclaw-agent-id": agent_id - } - body = { - "model": model, - "user": "connection-test", - "input": "hi", - "stream": False - } - - timeout_config = aiohttp.ClientTimeout(total=30) - async with aiohttp.ClientSession(timeout=timeout_config) as session: - async with session.post(url, json=body, headers=headers) as resp: - if resp.status < 400: - return {"success": True, "message": "OpenClaw 连接成功"} - error_text = await resp.text() - return { - "success": False, - "message": f"OpenClaw HTTP {resp.status}: {error_text[:200]}" - } - except Exception as e: - return {"success": False, "message": f"OpenClaw 连接失败: {str(e)}"} - #=============测试openclaw连接结束=========== def ensure_builtin_tools_initialized(self, tenant_id: uuid.UUID): - """确保内置工具已初始化""" - existing = self.tool_repo.exists_builtin_for_tenant(self.db, tenant_id) - - if existing: + """确保内置工具已初始化(支持增量补充新工具)""" + builtin_config = self._load_builtin_config() + if not builtin_config: return - # 从配置文件加载内置工具定义 - builtin_config = self._load_builtin_config() + existing_classes = self.builtin_repo.get_existing_tool_classes(self.db, tenant_id) + added = False for tool_key, tool_info in builtin_config.items(): + if tool_info['tool_class'] in existing_classes: + continue try: - # 创建工具配置 initial_status = self._determine_initial_status(tool_info) tool_config = ToolConfig( name=tool_info['name'], @@ -424,7 +368,6 @@ class ToolService: self.db.add(tool_config) self.db.flush() - # 创建内置工具配置 builtin_config_obj = BuiltinToolConfig( id=tool_config.id, tool_class=tool_info['tool_class'], @@ -432,12 +375,14 @@ class ToolService: requires_config=tool_info.get('requires_config', False) ) 
self.db.add(builtin_config_obj) + added = True except Exception as e: logger.error(f"初始化内置工具失败: {tool_key}, {e}") - self.db.commit() - logger.info(f"租户 {tenant_id} 内置工具初始化完成") + if added: + self.db.commit() + logger.info(f"租户 {tenant_id} 内置工具增量初始化完成") async def get_tool_methods(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[List[Dict[str, Any]]]: """获取工具的所有方法 @@ -515,6 +460,9 @@ class ToolService: # 对于json_tool,根据操作类型返回相关参数 elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool': return self._get_json_tool_params(operation) + # 对于openclaw_tool,根据操作类型返回不同描述的参数 + elif hasattr(tool_instance, 'name') and tool_instance.name == 'openclaw_tool': + return self._get_openclaw_tool_params(operation) # 其他工具的默认处理:返回除operation外的所有参数 return [{ @@ -744,6 +692,65 @@ class ToolService: return base_params + @staticmethod + def _get_openclaw_tool_params(operation: str) -> List[Dict[str, Any]]: + """获取 openclaw_tool 特定操作的参数""" + if operation == "print_task": + return [ + { + "name": "message", + "type": "string", + "description": "发送给 OpenClaw 的打印任务描述,将用户的原始消息原封不动地传递给 OpenClaw,禁止改写、补充或润色用户的原文", + "required": True + }, + { + "name": "image_url", + "type": "string", + "description": "可选,附带的设计图片或参考图,OpenClaw 可据此生成 3D 模型", + "required": False + } + ] + elif operation == "device_query": + return [ + { + "name": "message", + "type": "string", + "description": "发送给 OpenClaw 的设备查询指令", + "required": True + } + ] + elif operation == "image_understand": + return [ + { + "name": "message", + "type": "string", + "description": "发送给 OpenClaw 的图片理解任务,应描述需要对图片做什么(如描述内容、提取文字、分析信息)", + "required": True + }, + { + "name": "image_url", + "type": "string", + "description": "必须提供,要分析的图片 URL 或 base64 data URI", + "required": True + } + ] + else: + # general 及其他 + return [ + { + "name": "message", + "type": "string", + "description": "发送给 OpenClaw Agent 的任务描述,应包含完整的任务需求", + "required": True + }, + { + "name": "image_url", + "type": "string", + "description": "可选,附带的图片 URL 或 base64 
data URI", + "required": False + } + ] + async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: """获取自定义工具的方法""" custom_config = self.custom_repo.find_by_tool_id(self.db, config.id) @@ -1196,27 +1203,6 @@ class ToolService: custom_config = self.db.query(CustomToolConfig).filter( CustomToolConfig.id == tool_config.id ).first() - # ========== 更新工具 OpenClaw 特判 ========== - if custom_config and custom_config.schema_content: - schema = custom_config.schema_content - if isinstance(schema, str): - try: - schema = json.loads(schema) - except json.JSONDecodeError: - schema = {} - info = schema.get("info", {}) if isinstance(schema, dict) else {} - if info.get("x-openclaw"): - servers = schema.get("servers", []) - has_url = bool(servers and servers[0].get("url")) - has_agent_id = bool(info.get("x-openclaw-agent-id")) - has_token = bool(custom_config.auth_config - and custom_config.auth_config.get("api_key")) - if has_url and has_agent_id and has_token: - tool_config.status = ToolStatus.AVAILABLE.value - else: - tool_config.status = ToolStatus.UNCONFIGURED.value - return - # ========== OpenClaw 特判结束 ========== if custom_config and tool_config.name and (custom_config.schema_content or custom_config.schema_url): tool_config.status = ToolStatus.AVAILABLE.value From 5adff38bda98307ee04fc26f1eb0ed6194dac3d8 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 9 Apr 2026 18:58:21 +0800 Subject: [PATCH 13/85] feat(web): workflow check list --- web/src/assets/images/workflow/checkList.svg | 16 + web/src/assets/images/workflow/features.svg | 10 +- web/src/i18n/en.ts | 77 ++++- web/src/i18n/zh.ts | 37 +++ .../components/ConfigHeader.tsx | 4 +- web/src/views/ApplicationConfig/types.ts | 2 + .../Workflow/components/CheckList/index.tsx | 285 ++++++++++++++++++ .../Properties/CodeExecution/OutputList.tsx | 3 +- .../Properties/ConditionList/index.tsx | 2 +- .../Properties/HttpRequest/index.tsx | 6 +- .../Properties/Knowledge/Knowledge.tsx | 2 +- 
.../Properties/ModelConfig/index.tsx | 1 + .../Properties/ParamsList/index.tsx | 2 +- .../components/Properties/VariableSelect.tsx | 2 +- .../Workflow/components/Properties/index.tsx | 3 +- web/src/views/Workflow/constant.ts | 35 ++- .../views/Workflow/hooks/useWorkflowGraph.ts | 3 +- web/src/views/Workflow/index.tsx | 4 +- web/src/views/Workflow/types.ts | 1 + 19 files changed, 475 insertions(+), 20 deletions(-) create mode 100644 web/src/assets/images/workflow/checkList.svg create mode 100644 web/src/views/Workflow/components/CheckList/index.tsx diff --git a/web/src/assets/images/workflow/checkList.svg b/web/src/assets/images/workflow/checkList.svg new file mode 100644 index 00000000..169743dc --- /dev/null +++ b/web/src/assets/images/workflow/checkList.svg @@ -0,0 +1,16 @@ + + + 参与 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/features.svg b/web/src/assets/images/workflow/features.svg index 2ff48584..bd31b107 100644 --- a/web/src/assets/images/workflow/features.svg +++ b/web/src/assets/images/workflow/features.svg @@ -1,12 +1,14 @@ 参与 - - + + - - + + + + diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index d54f0d25..8f88b561 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1396,6 +1396,43 @@ export const en = { pleaseUploadFile: 'Please upload file', setting: 'Settings', features: 'Conversation Features', + checkList: 'Check List', + checkListDesc: 'Ensure all issues are resolved before publishing', + checkListEmpty: 'No issues found', + notConnected: 'This node is not connected to other nodes', + goto: 'Go to', + cannotBeEmpty: 'cannot be empty', + checkListErrors: { + 'llm.model_id': 'Model', + 'llm.messages': 'Messages', + 'end.output': 'Output', + 'knowledge-retrieval.knowledge_retrieval': 'Knowledge bases', + 'parameter-extractor.model_id': 'Model', + 'parameter-extractor.text': 'Input variable', + 'parameter-extractor.params': 'Params', + 'memory-read.message': 'Message', + 
'memory-read.config_id': 'Memory config', + 'memory-read.search_switch': 'Search mode', + 'memory-write.messages': 'Messages', + 'memory-write.config_id': 'Memory config', + 'if-else.cases': 'Condition', + 'question-classifier.model_id': 'Model', + 'question-classifier.input_variable': 'Input variable', + 'question-classifier.categories': 'Categories', + 'iteration.input': 'Input variable', + 'iteration.output': 'Output variable', + 'var-aggregator.group_variables': 'Variables', + 'assigner.assignments': 'Variables', + 'http-request.url': 'API URL', + 'http-request.body.data': 'Binary file variable', + 'code.input_variables': 'Input variables', + 'code.code': 'Code', + 'code.output_variables': 'Output variables', + 'jinja-render.mapping': 'Input variables', + 'jinja-render.template': 'Template', + 'document-extractor.file_selector': 'File variable', + 'list-operator.input_list': 'Input list', + }, file_upload: 'File Upload', file_upload_desc: 'The chat input box supports file uploads. Types include images, documents, and other types', settings: 'File Upload Settings', @@ -2442,7 +2479,45 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re iteration: 'Iteration', input_cycle_vars: 'Initial Loop Variables', output_cycle_vars: 'Final Loop Variables', - } + }, + sureReplace: '确认替换', + checkList: 'Check List', + checkListDesc: 'Ensure all issues are resolved before publishing', + checkListEmpty: 'No issues found', + notConnected: 'This node is not connected to other nodes', + goto: 'Go to', + cannotBeEmpty: 'cannot be empty', + checkListErrors: { + 'llm.model_id': 'Model', + 'llm.messages': 'Messages', + 'end.output': 'Output', + 'knowledge-retrieval.knowledge_retrieval': 'Knowledge bases', + 'parameter-extractor.model_id': 'Model', + 'parameter-extractor.text': 'Input variable', + 'parameter-extractor.params': 'Params', + 'memory-read.message': 'Message', + 'memory-read.config_id': 'Memory config', + 'memory-read.search_switch': 'Search 
mode', + 'memory-write.messages': 'Messages', + 'memory-write.config_id': 'Memory config', + 'if-else.cases': 'Condition', + 'question-classifier.model_id': 'Model', + 'question-classifier.input_variable': 'Input variable', + 'question-classifier.categories': 'Categories', + 'iteration.input': 'Input variable', + 'iteration.output': 'Output variable', + 'var-aggregator.group_variables': 'Variables', + 'assigner.assignments': 'Variables', + 'http-request.url': 'API URL', + 'http-request.body.data': 'Binary file variable', + 'code.input_variables': 'Input variables', + 'code.code': 'Code', + 'code.output_variables': 'Output variables', + 'jinja-render.mapping': 'Input variables', + 'jinja-render.template': 'Template', + 'document-extractor.file_selector': 'File variable', + 'list-operator.input_list': 'Input list', + }, }, emotionEngine: { emotionEngineConfig: 'Emotion Engine Configuration', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 5b46cb48..e521d5fa 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -2445,6 +2445,43 @@ export const zh = { output_cycle_vars: '最终循环变量', }, sureReplace: '确认替换', + checkList: '检查清单', + checkListDesc: '发布前确保所有问题均已解决', + checkListEmpty: '没有发现问题', + notConnected: '此节点尚未连接到其他节点', + goto: '转到', + cannotBeEmpty: '不能为空', + checkListErrors: { + 'llm.model_id': '模型', + 'llm.messages': '提示词', + 'end.output': '回复', + 'knowledge-retrieval.knowledge_retrieval': '知识库', + 'parameter-extractor.model_id': '模型', + 'parameter-extractor.text': '输入变量', + 'parameter-extractor.params': '提取参数', + 'memory-read.message': '消息', + 'memory-read.config_id': '记忆配置', + 'memory-read.search_switch': '检索模式', + 'memory-write.messages': '消息', + 'memory-write.config_id': '记忆配置', + 'if-else.cases': '条件', + 'question-classifier.model_id': '模型', + 'question-classifier.input_variable': '输入变量', + 'question-classifier.categories': '分类', + 'iteration.input': '输入变量', + 'iteration.output': '输出变量', + 'var-aggregator.group_variables': '变量', + 
'assigner.assignments': '变量', + 'http-request.url': 'API URL', + 'http-request.body.data': 'binary文件类型变量', + 'code.input_variables': '输入变量', + 'code.code': '代码', + 'code.output_variables': '输出变量', + 'jinja-render.mapping': '输入变量', + 'jinja-render.template': '模板', + 'document-extractor.file_selector': '文件变量', + 'list-operator.input_list': '输入变量', + }, }, emotionEngine: { emotionEngineConfig: '情感引擎配置', diff --git a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx index d77ae27c..b2e1e36b 100644 --- a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx +++ b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx @@ -4,7 +4,7 @@ * @Last Modified by: ZhaoYing * @Last Modified time: 2026-04-07 16:28:33 */ -import { type FC, useRef, useMemo, useCallback } from 'react'; +import { type FC, useRef, useMemo } from 'react'; import { useNavigate, useParams } from 'react-router-dom'; import { Tabs, Dropdown, Flex, Popover } from 'antd'; import type { MenuProps } from 'antd'; @@ -18,6 +18,7 @@ import type { CopyModalRef, AgentRef, ClusterRef, WorkflowRef, FeaturesConfigFor import { deleteApplication, appExport } from '@/api/application' import CopyModal from './CopyModal' import PageHeader from '@/components/Layout/PageHeader' +import CheckList from '@/views/Workflow/components/CheckList' /** * Tab keys for application configuration @@ -206,6 +207,7 @@ const ConfigHeader: FC = ({ } extra={application?.type === 'workflow' && source !== 'sharing' && activeTab === 'arrangement' ? +
void; handleSaveFeaturesConfig?: (value: FeaturesConfigForm) => void; + nodeClick: ({ node }: { node: Node }) => void; } /** diff --git a/web/src/views/Workflow/components/CheckList/index.tsx b/web/src/views/Workflow/components/CheckList/index.tsx new file mode 100644 index 00000000..fe627e03 --- /dev/null +++ b/web/src/views/Workflow/components/CheckList/index.tsx @@ -0,0 +1,285 @@ +import { type FC, useState, useCallback, useEffect, useRef } from 'react' +import { Popover, Flex } from 'antd' +import { WarningFilled } from '@ant-design/icons' +import { useTranslation } from 'react-i18next' +import { Node } from '@antv/x6'; + +import type { WorkflowRef } from '@/views/ApplicationConfig/types' +import { nodeLibrary } from '../../constant' +import { getToolMethods } from '@/api/tools' +import RbDrawer from '@/components/RbDrawer' + +interface CheckListProps { + workflowRef: React.RefObject +} + +interface CheckError { + key: string + message: string +} + +interface NodeCheckResult { + id: string + name: string + type: string + icon: string + errors: CheckError[] +} + +const allNodes = nodeLibrary.flatMap(c => c.nodes) +const nodeIconMap: Record = Object.fromEntries(allNodes.map(n => [n.type, n.icon])) +const nodeConfigMap: Record> = Object.fromEntries( + allNodes.filter(n => n.config).map(n => [n.type, n.config!]) +) + +// Special validators for fields that need deeper checks beyond simple empty check +const specialValidators: Record boolean> = { + // llm.messages: at least one message with non-empty content + 'llm.messages': (val: any[]) => !Array.isArray(val) || !val.some(m => m?.content && String(m.content).trim()), + // knowledge-retrieval.knowledge_retrieval: knowledge_bases array must be non-empty + 'knowledge-retrieval.knowledge_retrieval': (val: any) => !(val?.knowledge_bases?.length > 0), + 'memory-write.messages': (val: any[]) => !Array.isArray(val) || !val.some(m => m?.content && String(m.content).trim()), + // if-else.cases: every case must have at least
one expression, and every expression must be fully set + 'if-else.cases': (val: any[]) => { + if (!Array.isArray(val) || !val.length) return true + return val.some(c => { + if (!c?.expressions?.length) return true + return c.expressions.some((expr: any) => { + if (!expr?.left) return true + if (['not_empty', 'empty'].includes(expr.operator)) return false + return !(!!expr.left && (!!expr.right || typeof expr.right === 'boolean' || typeof expr.right === 'number')) + }) + }) + }, + // question-classifier.categories: at least one category must have a non-blank class_name + 'question-classifier.categories': (val: any[]) => !Array.isArray(val) || !val.some(c => c?.class_name && String(c.class_name).trim()), + // var-aggregator.group_variables: must be non-empty array + 'var-aggregator.group_variables': (val: any[]) => !Array.isArray(val) || !val.length, + // assigner.assignments: every item needs variable_selector + operation; value required unless operation is 'clear' (a missing or empty list is accepted) + 'assigner.assignments': (val: any[]) => { + if (!Array.isArray(val) || !val.length) return false + return val.some(a => { + if (!a?.variable_selector || !a?.operation) return true + if (a.operation === 'clear') return false + return a.value === undefined || a.value === null || a.value === '' + }) + }, + // http-request.body: binary content_type requires data + 'http-request.body': (val: any) => val?.content_type === 'binary' && !val?.data, + // tool.tool_parameters: validated async via API, placeholder always returns false + 'tool.tool_parameters': () => false, + // code.input_variables: if non-empty, every item must have both name and variable + 'code.input_variables': (val: any[]) => Array.isArray(val) && val.length > 0 && val.some(v => !v?.name || !v?.variable), + // code.output_variables: must be non-empty + 'code.output_variables': (val: any[]) => !Array.isArray(val) || !val.length, + // jinja-render.mapping: if non-empty, every item must have both a name and a value + 'jinja-render.mapping': (val: any[]) => Array.isArray(val) &&
val.length > 0 && val.some(v => !v?.name || !v?.value), +} + +// True for undefined, null, '' and empty arrays; every other value counts as filled in. +function isEmpty(val: any): boolean { + if (val === undefined || val === null || val === '') return true + if (Array.isArray(val)) return val.length === 0 + return false +} + +// Validates one node: each field its type marks as required goes through its special validator, or isEmpty when none is registered. +function validateNode(type: string, config: Record): CheckError[] { + const errors: CheckError[] = [] + const nodeConfig = nodeConfigMap[type] + if (!nodeConfig) return errors + + const get = (key: string) => config[key]?.defaultValue + + Object.entries(nodeConfig).forEach(([field, fieldConfig]) => { + if (!fieldConfig?.required) return + const val = get(field) + const specialKey = `${type}.${field}` + const specialValidator = specialValidators[specialKey] + const isInvalid = specialValidator ? specialValidator(val) : isEmpty(val) + if (isInvalid) errors.push({ key: specialKey, message: '' }) + }) + + // http-request body.data (binary) — not a top-level required field, check separately + if (type === 'http-request') { + const body = get('body') + if (body?.content_type === 'binary' && !body?.data) { + errors.push({ key: 'http-request.body.data', message: '' }) + } + } + + return errors +} + +// Workflow checklist: re-validates the graph (debounced by 500 ms) whenever nodes, node data or edges change, and lists the problems found per node. +const CheckList: FC = ({ workflowRef }) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + const [results, setResults] = useState([]) + const timerRef = useRef>() + + const runCheck = useCallback(async () => { + const graph = workflowRef.current?.graphRef?.current + if (!graph) return [] + + const nodes = graph.getNodes() + const edges = graph.getEdges() + const sourceIds = new Set() + const targetIds = new Set() + // child-to-child edges within same parent (cycle) + const childTargetIds = new Set() + edges.forEach(e => { + sourceIds.add(e.getSourceCellId()) +
targetIds.add(e.getTargetCellId()) + const srcData = graph.getCellById(e.getSourceCellId())?.getData() + const tgtData = graph.getCellById(e.getTargetCellId())?.getData() + if (srcData?.cycle && tgtData?.cycle && srcData.cycle === tgtData.cycle) { + childTargetIds.add(e.getTargetCellId()) + } + }) + + const checked: NodeCheckResult[] = [] + for (const node of nodes) { + const data = node.getData() + if (!data || ['add-node', 'notes', 'cycle-start', 'break'].includes(data.type)) continue + + const errors: CheckError[] = [] + + + // Check connectivity + const isChildNode = !!data.cycle + const hasIncoming = isChildNode ? childTargetIds.has(node.id) : !['start', 'cycle-start'].includes(data.type) ? targetIds.has(node.id) : true + if (!hasIncoming) { + errors.push({ key: 'notConnected', message: t('workflow.notConnected') }) + } + + // Validate config + const configErrors = validateNode(data.type, data.config ?? {}) + configErrors.forEach(e => { + errors.push({ key: e.key, message: `${t(`workflow.checkListErrors.${e.key}`)} ${t('workflow.cannotBeEmpty')}`.trim() }) + }) + + // Tool node: fetch parameters via API and check required fields + if (data.type === 'tool') { + const toolId = data.config?.tool_id?.defaultValue ?? data.config?.tool_id + const toolParameters = data.config?.tool_parameters?.defaultValue ?? data.config?.tool_parameters ?? {} + if (toolId) { + try { + const methods = await getToolMethods(toolId) as Array<{ name: string; parameters: Array<{ name: string; required: boolean }> }> + const operation = toolParameters?.operation + const method = operation ?
methods.find(m => m.name === operation) : methods[0] + if (method) { + const missingParams = method.parameters.filter(p => p.required && (toolParameters[p.name] === undefined || toolParameters[p.name] === null || toolParameters[p.name] === '')) + missingParams.forEach(p => errors.push({ key: 'tool.tool_parameters', message: `${p.name} ${t('workflow.cannotBeEmpty')}` })) + } + } catch { + // ignore API errors + } + } + } + + if (errors.length) { + checked.push({ + id: node.id, + name: data.name || t(`workflow.${data.type}`), + type: data.type, + icon: nodeIconMap[data.type] ?? '', + errors, + }) + } + } + + return checked + }, [workflowRef.current?.graphRef?.current, t]) + + const scheduleCheck = useCallback(() => { + clearTimeout(timerRef.current) + timerRef.current = setTimeout(async () => { + setResults(await runCheck()) + }, 500) + }, [runCheck]) + + useEffect(() => { + const graph = workflowRef.current?.graphRef?.current + if (!graph) return + const events = ['node:added', 'node:removed', 'node:change:data', 'edge:added', 'edge:removed'] + events.forEach(e => graph.on(e, scheduleCheck)) + scheduleCheck() + return () => { + events.forEach(e => graph.off(e, scheduleCheck)) + clearTimeout(timerRef.current) + } + }, [workflowRef.current?.graphRef?.current, scheduleCheck]) + + const handleOpen = () => { + setOpen(true) + } + + const focusNode = (id: string) => { + const graph = workflowRef.current?.graphRef?.current + if (!graph) return + const node = graph.getCellById(id) + if (node) { + workflowRef.current?.nodeClick({node} as { node: Node }) + } + setOpen(false) + } + + return ( + <> + +
+
+ {results.length > 0 && ( + + {results.reduce((sum, n) => sum + n.errors.length, 0)} + + )} +
+ + + {t('workflow.checkList')}{results.length > 0 ? `(${results.reduce((sum, n) => sum + n.errors.length, 0)})` : ''} + + } + open={open} + onClose={() => setOpen(false)} + width={360} + styles={{ body: { padding: '12px 16px' } }} + > +

{t('workflow.checkListDesc')}

+ {results.length === 0 + ?
{t('workflow.checkListEmpty')}
+ : + {results.map(node => ( +
+ +
+ {node.name} + focusNode(node.id)} + > + {t('workflow.goto')} → + + + + + {node.errors.map((err, i) => ( + + + {err.message} + + ))} + +
+ ))} +
+ } + + + ) +} + +export default CheckList diff --git a/web/src/views/Workflow/components/Properties/CodeExecution/OutputList.tsx b/web/src/views/Workflow/components/Properties/CodeExecution/OutputList.tsx index 6bb12d9b..0080c493 100644 --- a/web/src/views/Workflow/components/Properties/CodeExecution/OutputList.tsx +++ b/web/src/views/Workflow/components/Properties/CodeExecution/OutputList.tsx @@ -27,7 +27,8 @@ const OutputList: FC = ({ label, name, extra }) => { <>
- {label} + + *{label}
diff --git a/web/src/views/Workflow/components/Properties/ConditionList/index.tsx b/web/src/views/Workflow/components/Properties/ConditionList/index.tsx index d484da09..cad32e37 100644 --- a/web/src/views/Workflow/components/Properties/ConditionList/index.tsx +++ b/web/src/views/Workflow/components/Properties/ConditionList/index.tsx @@ -58,7 +58,7 @@ const ConditionList: FC = ({ const { t } = useTranslation(); const form = Form.useFormInstance(); - const handleLeftFieldChange = (index: number, newValue: string) => { + const handleLeftFieldChange = (index: number, newValue?: string | string[]) => { form.setFieldsValue({ [parentName]: { expressions: { diff --git a/web/src/views/Workflow/components/Properties/HttpRequest/index.tsx b/web/src/views/Workflow/components/Properties/HttpRequest/index.tsx index 53714327..a02549bd 100644 --- a/web/src/views/Workflow/components/Properties/HttpRequest/index.tsx +++ b/web/src/views/Workflow/components/Properties/HttpRequest/index.tsx @@ -87,7 +87,9 @@ const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: an return ( <> -
API
+
+ *API +