diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 60c22855..807c59f4 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -62,10 +62,10 @@ celery_app.conf.update( task_serializer='json', accept_content=['json'], result_serializer='json', - - # 时区 - timezone='Asia/Shanghai', - enable_utc=False, + + # # 时区 + # timezone='Asia/Shanghai', + # enable_utc=False, # 任务追踪 task_track_started=True, @@ -116,6 +116,7 @@ celery_app.conf.update( 'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'}, 'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'}, 'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'}, + 'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'}, }, ) diff --git a/api/app/controllers/mcp_market_config_controller.py b/api/app/controllers/mcp_market_config_controller.py index 5b71190d..0f2da3b0 100644 --- a/api/app/controllers/mcp_market_config_controller.py +++ b/api/app/controllers/mcp_market_config_controller.py @@ -55,6 +55,12 @@ async def get_mcp_servers( status_code=status.HTTP_400_BAD_REQUEST, detail="The paging parameter must be greater than 0" ) + if page * pagesize > 100: + api_logger.warning(f"Paging parameters exceed ModelScope limit: page={page}, pagesize={pagesize}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"The maximum number of MCP services can view is 100. Please visit the ModelScope MCP Plaza." + ) # 2. Query mcp market config information from the database api_logger.debug(f"Query mcp market config: {mcp_market_config_id}") @@ -64,14 +70,16 @@ async def get_mcp_servers( if not db_mcp_market_config: api_logger.warning( f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The mcp market config does not exist or access is denied" - ) + return success(msg='The mcp market config does not exist or access is denied') # 3. Execute paged query - api = MCPApi() token = db_mcp_market_config.token + if not token: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="MCP market config token is not configured" + ) + api = MCPApi() api.login(token) body = { @@ -151,14 +159,16 @@ async def get_operational_mcp_servers( if not db_mcp_market_config: api_logger.warning( f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The mcp market config does not exist or access is denied" - ) + return success(msg='The mcp market config does not exist or access is denied') # 2. Execute paged query - api = MCPApi() token = db_mcp_market_config.token + if not token: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="MCP market config token is not configured" + ) + api = MCPApi() api.login(token) url = f'{api.mcp_base_url}/operational' @@ -209,14 +219,16 @@ async def get_mcp_server( if not db_mcp_market_config: api_logger.warning( f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The mcp market config does not exist or access is denied" - ) + return success(msg='The mcp market config does not exist or access is denied') # 2. Get detailed information for a specific MCP Server - api = MCPApi() token = db_mcp_market_config.token + if not token: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="MCP market config token is not configured" + ) + api = MCPApi() api.login(token) result = api.get_mcp_server(server_id=server_id) @@ -237,7 +249,26 @@ async def create_mcp_market_config( try: api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}") - # 1. Check if the mcp market name already exists + # 1. Validate token can access ModelScope MCP market + if not create_data.token: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Token is required to access ModelScope MCP market" + ) + try: + api = MCPApi() + api.login(create_data.token) + body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} + cookies = api.get_cookies(create_data.token) + r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) + raise_for_http_status(r) + except Exception as e: + api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}" + ) + # 2. Check if the mcp market name already exists db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user) if db_mcp_market_config_exist: api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}") @@ -245,6 +276,30 @@ async def create_mcp_market_config( status_code=status.HTTP_400_BAD_REQUEST, detail=f"The mcp market id already exists: {create_data.mcp_market_id}" ) + # 2. verify token + create_data.status = 1 + try: + api = MCPApi() + token = create_data.token + api.login(token) + + body = { + 'filter': {}, + 'page_number': 1, + 'page_size': 20, + 'search': "" + } + cookies = api.get_cookies(token) + r = api.session.put( + url=api.mcp_base_url, + headers=api.builder_headers(api.headers), + json=body, + cookies=cookies) + raise_for_http_status(r) + except requests.exceptions.RequestException as e: + api_logger.error(f"Failed to get MCP servers: {str(e)}") + create_data.status = 0 + # 3. create mcp_market_config db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user) api_logger.info( f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})") @@ -273,10 +328,7 @@ async def get_mcp_market_config( db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user) if not db_mcp_market_config: api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The mcp market config does not exist or access is denied" - ) + return success(msg='The mcp market config does not exist or access is denied') api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})") return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)), @@ -306,10 +358,7 @@ async def get_mcp_market_config_by_mcp_market_id( db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user) if not db_mcp_market_config: api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The mcp market config does not exist or access is denied" - ) + return success(msg='The mcp market config does not exist or access is denied') api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})") return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)), @@ -335,12 +384,25 @@ async def update_mcp_market_config( if not db_mcp_market_config: api_logger.warning( f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The mcp market config does not exist or you do not have permission to access it" - ) + return success(msg='The mcp market config does not exist or access is denied') - # 2. Update fields (only update non-null fields) + # 2. Validate new token if provided + if update_data.token is not None: + try: + api = MCPApi() + api.login(update_data.token) + body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} + cookies = api.get_cookies(update_data.token) + r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) + raise_for_http_status(r) + except Exception as e: + api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}" + ) + + # 3. Update fields (only update non-null fields) api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}") update_dict = update_data.dict(exclude_unset=True) updated_fields = [] @@ -355,7 +417,7 @@ async def update_mcp_market_config( if updated_fields: api_logger.debug(f"updated fields: {', '.join(updated_fields)}") - # 3. Save to database + # 4. Save to database try: db.commit() db.refresh(db_mcp_market_config) @@ -368,7 +430,7 @@ async def update_mcp_market_config( detail=f"The mcp market config update failed: {str(e)}" ) - # 4. Return the updated mcp market config + # 5. Return the updated mcp market config return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)), msg="The mcp market config information updated successfully") @@ -392,10 +454,7 @@ async def delete_mcp_market_config( if not db_mcp_market_config: api_logger.warning( f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The mcp market config does not exist or you do not have permission to access it" - ) + return success(msg='The mcp market config does not exist or access is denied') # 2. Deleting mcp market config mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 3bbb5cf7..22fd2c6c 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -193,7 +193,16 @@ async def get_workspace_end_users( await aio_redis_set(cache_key, json.dumps(result), expire=30) except Exception as e: api_logger.warning(f"Redis 缓存写入失败: {str(e)}") - + + # 触发社区聚类补全任务(异步,不阻塞接口响应) + # 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类 + try: + from app.tasks import init_community_clustering_for_users + init_community_clustering_for_users.delay(end_user_ids=end_user_ids) + api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}") + except Exception as e: + api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}") + api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") return success(data=result, msg="宿主列表获取成功") @@ -403,14 +412,15 @@ def get_current_user_rag_total_num( @router.get("/rag_content", response_model=ApiResponse) def get_rag_content( end_user_id: str = Query(..., description="宿主ID"), - limit: int = Query(15, description="返回记录数"), + page: int = Query(1, gt=0, description="页码,从1开始"), + pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """ - 获取当前宿主知识库中的chunk内容 + 获取当前宿主知识库中的chunk内容(分页) """ - data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user) + data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user) return success(data=data, msg="宿主RAGchunk数据获取成功") diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index d3fe7d83..be796ff9 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -17,6 +17,7 @@ from app.services.user_memory_service import ( UserMemoryService, analytics_memory_types, analytics_graph_data, + analytics_community_graph_data, ) from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction from app.schemas.response_schema import ApiResponse @@ -295,6 +296,42 @@ async def get_graph_data_api( return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e)) +@router.get("/analytics/community_graph", response_model=ApiResponse) +async def get_community_graph_data_api( + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + workspace_id = current_user.current_workspace_id + + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, " + f"workspace={workspace_id}" + ) + + try: + result = await analytics_community_graph_data(db=db, end_user_id=end_user_id) + + if "message" in result and result["statistics"]["total_nodes"] == 0: + api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}") + return success(data=result, msg=result.get("message", "查询成功")) + + api_logger.info( + f"成功获取社区图谱: end_user_id={end_user_id}, " + f"nodes={result['statistics']['total_nodes']}, " + f"edges={result['statistics']['total_edges']}" + ) + return success(data=result, msg="查询成功") + + except Exception as e: + api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e)) + + @router.get("/read_end_user/profile", response_model=ApiResponse) async def get_end_user_profile( end_user_id: str, diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py index ca08db76..5241ac89 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py @@ -1,4 +1,5 @@ from app.core.memory.agent.utils.llm_tools import ReadState, WriteState +from app.schemas.memory_agent_schema import AgentMemoryDataset def content_input_node(state: ReadState) -> ReadState: @@ -17,6 +18,9 @@ def content_input_node(state: ReadState) -> ReadState: content = state['messages'][0].content if state.get('messages') else '' # Return content and maintain all state information + for pronoun in AgentMemoryDataset.PRONOUN: + content = content.replace(pronoun, AgentMemoryDataset.NAME) + return {"data": content} @@ -35,4 +39,7 @@ def content_input_write(state: WriteState) -> WriteState: content = state['messages'][0].content if state.get('messages') else '' # Return content and maintain all state information + for pronoun in AgentMemoryDataset.PRONOUN: + content = content.replace(pronoun, AgentMemoryDataset.NAME) + return {"data": content} diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 22030278..b3707083 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -165,7 +165,9 @@ async def write( statement_chunk_edges=all_statement_chunk_edges, statement_entity_edges=all_statement_entity_edges, entity_edges=all_entity_entity_edges, - connector=neo4j_connector + connector=neo4j_connector, + config_id=config_id, + llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, ) if success: logger.info("Successfully saved all data to Neo4j") diff --git a/api/app/core/memory/storage_services/clustering_engine/__init__.py b/api/app/core/memory/storage_services/clustering_engine/__init__.py new file mode 100644 index 00000000..992d8bff --- /dev/null +++ b/api/app/core/memory/storage_services/clustering_engine/__init__.py @@ -0,0 +1,3 @@ +from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine + +__all__ = ["LabelPropagationEngine"] diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py new file mode 100644 index 00000000..cbc303b1 --- /dev/null +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -0,0 +1,484 @@ +"""标签传播聚类引擎 + +基于 ZEP 论文的动态标签传播算法,对 Neo4j 中的 ExtractedEntity 节点进行社区聚类。 + +支持两种模式: +- 全量初始化(full_clustering):首次运行,对所有实体做完整 LPA 迭代 +- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居 +""" + +import logging +import uuid +from math import sqrt +from typing import Dict, List, Optional + +from app.repositories.neo4j.community_repository import CommunityRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + +logger = logging.getLogger(__name__) + +# 全量迭代最大轮数,防止不收敛 +MAX_ITERATIONS = 10 +# 社区摘要核心实体数量 +CORE_ENTITY_LIMIT = 5 + + +def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: + """计算两个向量的余弦相似度,任一为空则返回 0。""" + if not v1 or not v2 or len(v1) != len(v2): + return 0.0 + dot = sum(a * b for a, b in zip(v1, v2)) + norm1 = sqrt(sum(a * a for a in v1)) + norm2 = sqrt(sum(b * b for b in v2)) + if norm1 == 0 or norm2 == 0: + return 0.0 + return dot / (norm1 * norm2) + + +def _weighted_vote( + neighbors: List[Dict], + self_embedding: Optional[List[float]], +) -> Optional[str]: + """ + 加权多数投票,选出得票最高的社区。 + + 权重 = 语义相似度(name_embedding 余弦)* activation_value 加成 + 没有 community_id 的邻居不参与投票。 + """ + votes: Dict[str, float] = {} + for nb in neighbors: + cid = nb.get("community_id") + if not cid: + continue + sem = _cosine_similarity(self_embedding, nb.get("name_embedding")) + act = nb.get("activation_value") or 0.5 + # 语义相似度权重 0.6,激活值权重 0.4 + weight = 0.6 * sem + 0.4 * act + votes[cid] = votes.get(cid, 0.0) + weight + + if not votes: + return None + return max(votes, key=votes.__getitem__) + + +class LabelPropagationEngine: + """标签传播聚类引擎""" + + def __init__( + self, + connector: Neo4jConnector, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, + ): + self.connector = connector + self.repo = CommunityRepository(connector) + self.config_id = config_id + self.llm_model_id = llm_model_id + + # ────────────────────────────────────────────────────────────────────────── + # 公开接口 + # ────────────────────────────────────────────────────────────────────────── + + async def run( + self, + end_user_id: str, + new_entity_ids: Optional[List[str]] = None, + ) -> None: + """ + 统一入口:自动判断全量还是增量。 + + - 若该用户尚无 Community 节点 → 全量初始化 + - 否则 → 增量更新(仅处理 new_entity_ids) + """ + has_communities = await self.repo.has_communities(end_user_id) + if not has_communities: + logger.info(f"[Clustering] 用户 {end_user_id} 首次聚类,执行全量初始化") + await self.full_clustering(end_user_id) + else: + if new_entity_ids: + logger.info( + f"[Clustering] 增量更新,新实体数: {len(new_entity_ids)}" + ) + await self.incremental_update(new_entity_ids, end_user_id) + + async def full_clustering(self, end_user_id: str) -> None: + """ + 全量标签传播初始化。 + + 1. 拉取所有实体,初始化每个实体为独立社区 + 2. 迭代:每轮对所有实体做邻居投票,更新社区标签 + 3. 直到标签不再变化或达到 MAX_ITERATIONS + 4. 将最终标签写入 Neo4j + """ + entities = await self.repo.get_all_entities(end_user_id) + if not entities: + logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") + return + + # 初始化:每个实体持有自己 id 作为社区标签 + labels: Dict[str, str] = {e["id"]: e["id"] for e in entities} + embeddings: Dict[str, Optional[List[float]]] = { + e["id"]: e.get("name_embedding") for e in entities + } + + # 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返 + logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...") + neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id) + logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}") + + for iteration in range(MAX_ITERATIONS): + changed = 0 + # 随机顺序(Python dict 在 3.7+ 保持插入顺序,这里直接遍历) + for entity in entities: + eid = entity["id"] + # 直接从缓存取邻居,不再发起 Neo4j 查询 + neighbors = neighbors_cache.get(eid, []) + + # 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值) + enriched = [] + for nb in neighbors: + nb_copy = dict(nb) + nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id")) + enriched.append(nb_copy) + + new_label = _weighted_vote(enriched, embeddings.get(eid)) + if new_label and new_label != labels[eid]: + labels[eid] = new_label + changed += 1 + + logger.info( + f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}," + f"标签变化数: {changed}" + ) + if changed == 0: + logger.info("[Clustering] 标签已收敛,提前结束迭代") + break + + # 将最终标签写入 Neo4j + await self._flush_labels(labels, end_user_id) + pre_merge_count = len(set(labels.values())) + logger.info( + f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区," + f"{len(labels)} 个实体,开始后处理合并" + ) + + # 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度) + all_community_ids = list(set(labels.values())) + await self._evaluate_merge(all_community_ids, end_user_id) + + logger.info( + f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," + f"{len(labels)} 个实体" + ) + # 为所有社区生成元数据 + # 注意:_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活的社区 + # 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID + surviving_communities = await self.repo.get_all_entities(end_user_id) + surviving_community_ids = list({ + e.get("community_id") for e in surviving_communities + if e.get("community_id") + }) + logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}") + for cid in surviving_community_ids: + await self._generate_community_metadata(cid, end_user_id) + + async def incremental_update( + self, new_entity_ids: List[str], end_user_id: str + ) -> None: + """ + 增量更新:只处理新实体及其邻居,不重跑全图。 + + 1. 对每个新实体查询邻居 + 2. 加权多数投票决定社区归属 + 3. 若邻居无社区 → 创建新社区 + 4. 若邻居分属多个社区 → 评估是否合并 + """ + for entity_id in new_entity_ids: + await self._process_single_entity(entity_id, end_user_id) + + # ────────────────────────────────────────────────────────────────────────── + # 内部方法 + # ────────────────────────────────────────────────────────────────────────── + + async def _process_single_entity( + self, entity_id: str, end_user_id: str + ) -> None: + """处理单个新实体的社区分配。""" + neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id) + + # 查询自身 embedding(从邻居查询结果中无法获取,需单独查) + self_embedding = await self._get_entity_embedding(entity_id, end_user_id) + + if not neighbors: + # 孤立实体:创建单成员社区 + new_cid = self._new_community_id() + await self.repo.upsert_community(new_cid, end_user_id, member_count=1) + await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) + logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}") + return + + # 统计邻居社区分布 + community_ids_in_neighbors = set( + nb["community_id"] for nb in neighbors if nb.get("community_id") + ) + + target_cid = _weighted_vote(neighbors, self_embedding) + + if target_cid is None: + # 邻居都没有社区,连同新实体一起创建新社区 + new_cid = self._new_community_id() + await self.repo.upsert_community(new_cid, end_user_id) + await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) + for nb in neighbors: + await self.repo.assign_entity_to_community( + nb["id"], new_cid, end_user_id + ) + await self.repo.refresh_member_count(new_cid, end_user_id) + logger.debug( + f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" + ) + await self._generate_community_metadata(new_cid, end_user_id) + else: + # 加入得票最多的社区 + await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) + await self.repo.refresh_member_count(target_cid, end_user_id) + logger.debug(f"[Clustering] 新实体 {entity_id} → 社区 {target_cid}") + + # 若邻居分属多个社区,评估合并 + if len(community_ids_in_neighbors) > 1: + await self._evaluate_merge( + list(community_ids_in_neighbors), end_user_id + ) + await self._generate_community_metadata(target_cid, end_user_id) + + async def _evaluate_merge( + self, community_ids: List[str], end_user_id: str + ) -> None: + """ + 评估多个社区是否应合并。 + + 策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。 + 合并时保留成员数最多的社区,其余成员迁移过来。 + + 全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。 + """ + MERGE_THRESHOLD = 0.85 + BATCH_THRESHOLD = 20 # 超过此数量走批量查询 + + community_embeddings: Dict[str, Optional[List[float]]] = {} + community_sizes: Dict[str, int] = {} + + if len(community_ids) > BATCH_THRESHOLD: + # 批量查询:一次拉取所有社区成员 + all_members = await self.repo.get_all_community_members_batch( + community_ids, end_user_id + ) + for cid in community_ids: + members = all_members.get(cid, []) + community_sizes[cid] = len(members) + valid_embeddings = [ + m["name_embedding"] for m in members if m.get("name_embedding") + ] + if valid_embeddings: + dim = len(valid_embeddings[0]) + community_embeddings[cid] = [ + sum(e[i] for e in valid_embeddings) / len(valid_embeddings) + for i in range(dim) + ] + else: + community_embeddings[cid] = None + else: + # 增量场景:逐个查询 + for cid in community_ids: + members = await self.repo.get_community_members(cid, end_user_id) + community_sizes[cid] = len(members) + valid_embeddings = [ + m["name_embedding"] for m in members if m.get("name_embedding") + ] + if valid_embeddings: + dim = len(valid_embeddings[0]) + community_embeddings[cid] = [ + sum(e[i] for e in valid_embeddings) / len(valid_embeddings) + for i in range(dim) + ] + else: + community_embeddings[cid] = None + + # 找出应合并的社区对 + to_merge: List[tuple] = [] + cids = list(community_ids) + for i in range(len(cids)): + for j in range(i + 1, len(cids)): + sim = _cosine_similarity( + community_embeddings[cids[i]], + community_embeddings[cids[j]], + ) + if sim > MERGE_THRESHOLD: + to_merge.append((cids[i], cids[j])) + + logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区") + + # 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量 + # 避免 union-find 链式传递导致语义不相关的社区被间接合并 + # (A≈B、B≈C 不代表 A≈C,不能因传递性把 A/B/C 全部合并) + merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射 + + def get_root(x: str) -> str: + """路径压缩,找到 x 当前所属的根社区。""" + while x in merged_into: + merged_into[x] = merged_into.get(merged_into[x], merged_into[x]) + x = merged_into[x] + return x + + for c1, c2 in to_merge: + root1, root2 = get_root(c1), get_root(c2) + if root1 == root2: + continue + + # 用合并后的最新平均向量重新验证相似度 + # 防止链式传递:A≈B 合并后 B 的向量已更新,C 必须和新 B 相似才能合并 + current_sim = _cosine_similarity( + community_embeddings.get(root1), + community_embeddings.get(root2), + ) + if current_sim <= MERGE_THRESHOLD: + # 合并后向量已漂移,不再满足阈值,跳过 + logger.debug( + f"[Clustering] 跳过合并 {root1} ↔ {root2}," + f"当前相似度 {current_sim:.3f} ≤ {MERGE_THRESHOLD}" + ) + continue + + keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2 + dissolve = root2 if keep == root1 else root1 + merged_into[dissolve] = keep + + members = await self.repo.get_community_members(dissolve, end_user_id) + for m in members: + await self.repo.assign_entity_to_community(m["id"], keep, end_user_id) + + # 合并后重新计算 keep 的平均向量(加权平均) + keep_emb = community_embeddings.get(keep) + dissolve_emb = community_embeddings.get(dissolve) + keep_size = community_sizes.get(keep, 0) + dissolve_size = community_sizes.get(dissolve, 0) + total_size = keep_size + dissolve_size + if keep_emb and dissolve_emb and total_size > 0: + dim = len(keep_emb) + community_embeddings[keep] = [ + (keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size + for i in range(dim) + ] + community_embeddings[dissolve] = None + + community_sizes[keep] = total_size + community_sizes[dissolve] = 0 + await self.repo.refresh_member_count(keep, end_user_id) + logger.info( + f"[Clustering] 社区合并: {dissolve} → {keep}," + f"相似度={current_sim:.3f},迁移 {len(members)} 个成员" + ) + + async def _flush_labels( + self, labels: Dict[str, str], end_user_id: str + ) -> None: + """将内存中的标签批量写入 Neo4j。""" + # 先创建所有唯一社区节点 + unique_communities = set(labels.values()) + for cid in unique_communities: + await self.repo.upsert_community(cid, end_user_id) + + # 再批量分配实体 + for entity_id, community_id in labels.items(): + await self.repo.assign_entity_to_community( + entity_id, community_id, end_user_id + ) + + # 刷新成员数 + for cid in unique_communities: + await self.repo.refresh_member_count(cid, end_user_id) + + async def _get_entity_embedding( + self, entity_id: str, end_user_id: str + ) -> Optional[List[float]]: + """查询单个实体的 name_embedding。""" + try: + result = await self.connector.execute_query( + "MATCH (e:ExtractedEntity {id: $eid, end_user_id: $uid}) " + "RETURN e.name_embedding AS name_embedding", + eid=entity_id, + uid=end_user_id, + ) + return result[0]["name_embedding"] if result else None + except Exception: + return None + + async def _generate_community_metadata( + self, community_id: str, end_user_id: str + ) -> None: + """ + 为社区生成并写入元数据:名称、摘要、核心实体。 + + - core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM) + - name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底 + """ + try: + members = await self.repo.get_community_members(community_id, end_user_id) + if not members: + return + + # 核心实体:按 activation_value 降序取 top-N + sorted_members = sorted( + members, + key=lambda m: m.get("activation_value") or 0, + reverse=True, + ) + core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] + all_names = [m["name"] for m in members if m.get("name")] + + name = "、".join(core_entities[:3]) if core_entities else community_id[:8] + summary = f"包含实体:{', '.join(all_names)}" + + # 若有 LLM 配置,调用 LLM 生成更好的名称和摘要 + if self.llm_model_id: + try: + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + + entity_list_str = "、".join(all_names) + prompt = ( + f"以下是一组语义相关的实体:{entity_list_str}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 写一句话摘要(不超过50个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(self.llm_model_id) + response = await llm_client.chat([{"role": "user", "content": prompt}]) + text = response.content if hasattr(response, "content") else str(response) + + for line in text.strip().splitlines(): + if line.startswith("名称:"): + name = line[3:].strip() + elif line.startswith("摘要:"): + summary = line[3:].strip() + except Exception as e: + logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}") + + await self.repo.update_community_metadata( + community_id=community_id, + end_user_id=end_user_id, + name=name, + summary=summary, + core_entities=core_entities, + ) + logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}") + except Exception as e: + logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}") + + @staticmethod + def _new_community_id() -> str: + return str(uuid.uuid4()) diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index c082b314..f19902a2 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -53,6 +53,7 @@ class SimpleMCPClient: else: await self._connect_http() except Exception as e: + await self.disconnect() logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}") raise MCPConnectionError(f"连接失败: {e}") diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index 0b828a8b..61faf6d4 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -247,7 +247,6 @@ class EndUserRepository: EndUser.user_summary: user_summary, EndUser.rag_tags: rag_tags, EndUser.rag_personas: rag_personas, - EndUser.storage_type: "rag", EndUser.rag_summary_updated_at: datetime.datetime.now(), }, synchronize_session=False @@ -286,7 +285,6 @@ class EndUserRepository: .update( { EndUser.memory_insight: memory_insight, - EndUser.storage_type: "rag", EndUser.memory_insight_updated_at: datetime.datetime.now(), }, synchronize_session=False diff --git a/api/app/repositories/implicit_emotions_storage_repository.py b/api/app/repositories/implicit_emotions_storage_repository.py index f0871b4b..b6c40b40 100644 --- a/api/app/repositories/implicit_emotions_storage_repository.py +++ b/api/app/repositories/implicit_emotions_storage_repository.py @@ -5,7 +5,7 @@ Implicit Emotions Storage Repository 事务由调用方控制,仓储层只使用 flush/refresh """ import logging -from datetime import date, datetime, timedelta, timezone +from datetime import date, datetime, timezone from typing import Generator, Optional @@ -177,22 +177,21 @@ class ImplicitEmotionsStorageRepository: if raw is None: continue try: - CST = timezone(timedelta(hours=8)) last_done = datetime.fromisoformat(raw) - # last_done 写入时已是 CST naive,直接使用,无需转换 - if last_done.tzinfo is not None: - last_done = last_done.astimezone(CST).replace(tzinfo=None) + # last_done 写入时已是 UTC aware(+00:00),确保有 tzinfo + if last_done.tzinfo is None: + last_done = last_done.replace(tzinfo=timezone.utc) if updated_at is None: yield end_user_id continue - # updated_at 数据库存的是 UTC naive,转为 CST naive 再比较 + # updated_at 数据库存的是 UTC naive,补上 UTC tzinfo 再比较 if updated_at.tzinfo is None: - updated_at_cst = updated_at.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None) + updated_at_utc = updated_at.replace(tzinfo=timezone.utc) else: - updated_at_cst = updated_at.astimezone(CST).replace(tzinfo=None) + updated_at_utc = updated_at.astimezone(timezone.utc) - if last_done > updated_at_cst: + if last_done > updated_at_utc: yield end_user_id except Exception as e: logger.warning(f"解析 last_done 时间戳失败: end_user_id={end_user_id}, raw={raw}, error={e}") diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py new file mode 100644 index 00000000..f2f11f76 --- /dev/null +++ b/api/app/repositories/neo4j/community_repository.py @@ -0,0 +1,194 @@ +"""Community 节点仓库 + +管理 Neo4j 中 Community 节点及 BELONGS_TO_COMMUNITY 边的 CRUD 操作。 +""" + +import logging +from typing import Dict, List, Optional + +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.repositories.neo4j.cypher_queries import ( + COMMUNITY_NODE_UPSERT, + ENTITY_JOIN_COMMUNITY, + ENTITY_LEAVE_ALL_COMMUNITIES, + GET_ENTITY_NEIGHBORS, + GET_ALL_ENTITIES_FOR_USER, + GET_COMMUNITY_MEMBERS, + GET_ALL_COMMUNITY_MEMBERS_BATCH, + GET_ALL_ENTITY_NEIGHBORS_BATCH, + CHECK_USER_HAS_COMMUNITIES, + UPDATE_COMMUNITY_MEMBER_COUNT, + UPDATE_COMMUNITY_METADATA, +) + +logger = logging.getLogger(__name__) + + +class CommunityRepository: + def __init__(self, connector: Neo4jConnector): + self.connector = connector + + async def upsert_community( + self, community_id: str, end_user_id: str, member_count: int = 0 + ) -> Optional[str]: + """创建或更新 Community 节点,返回 community_id。""" + try: + result = await self.connector.execute_query( + COMMUNITY_NODE_UPSERT, + community_id=community_id, + end_user_id=end_user_id, + member_count=member_count, + ) + return result[0]["community_id"] if result else None + except Exception as e: + logger.error(f"upsert_community failed: {e}") + return None + + async def assign_entity_to_community( + self, entity_id: str, community_id: str, end_user_id: str + ) -> bool: + """将实体关联到社区(先解除旧关联,再建立新关联)。""" + try: + await self.connector.execute_query( + ENTITY_LEAVE_ALL_COMMUNITIES, + entity_id=entity_id, + end_user_id=end_user_id, + ) + result = await self.connector.execute_query( + ENTITY_JOIN_COMMUNITY, + entity_id=entity_id, + community_id=community_id, + end_user_id=end_user_id, + ) + return bool(result) + except Exception as e: + logger.error(f"assign_entity_to_community failed: {e}") + return False + + async def get_entity_neighbors( + self, entity_id: str, end_user_id: str + ) -> List[Dict]: + """查询实体的直接邻居及其社区归属。""" + try: + return await self.connector.execute_query( + GET_ENTITY_NEIGHBORS, + entity_id=entity_id, + end_user_id=end_user_id, + ) + except Exception as e: + logger.error(f"get_entity_neighbors failed: {e}") + return [] + + async def get_all_entity_neighbors_batch( + self, end_user_id: str + ) -> Dict[str, List[Dict]]: + """一次性批量拉取该用户下所有实体的邻居,返回 {entity_id: [neighbors]} 字典。 + 用于全量聚类预加载,避免每个实体单独查询。""" + try: + rows = await self.connector.execute_query( + GET_ALL_ENTITY_NEIGHBORS_BATCH, + end_user_id=end_user_id, + ) + result: Dict[str, List[Dict]] = {} + for row in rows: + eid = row["entity_id"] + neighbor = {k: v for k, v in row.items() if k != "entity_id"} + result.setdefault(eid, []).append(neighbor) + return result + except Exception as e: + logger.error(f"get_all_entity_neighbors_batch failed: {e}") + return {} + + async def get_all_entities(self, end_user_id: str) -> List[Dict]: + """拉取某用户下所有实体及其当前社区归属。""" + try: + return await self.connector.execute_query( + GET_ALL_ENTITIES_FOR_USER, + end_user_id=end_user_id, + ) + except Exception as e: + logger.error(f"get_all_entities failed: {e}") + return [] + + async def get_community_members( + self, community_id: str, end_user_id: str + ) -> List[Dict]: + """查询社区成员列表。""" + try: + return await self.connector.execute_query( + GET_COMMUNITY_MEMBERS, + community_id=community_id, + end_user_id=end_user_id, + ) + except Exception as e: + logger.error(f"get_community_members failed: {e}") + return [] + + async def get_all_community_members_batch( + self, community_ids: List[str], end_user_id: str + ) -> Dict[str, List[Dict]]: + """批量查询多个社区的成员,返回 {community_id: [members]} 字典。""" + try: + rows = await self.connector.execute_query( + GET_ALL_COMMUNITY_MEMBERS_BATCH, + community_ids=community_ids, + end_user_id=end_user_id, + ) + result: Dict[str, List[Dict]] = {} + for row in rows: + cid = row["community_id"] + result.setdefault(cid, []).append(row) + return result + except Exception as e: + logger.error(f"get_all_community_members_batch failed: {e}") + return {} + + async def has_communities(self, end_user_id: str) -> bool: + """检查该用户是否已有 Community 节点(用于判断全量 vs 增量)。""" + try: + result = await self.connector.execute_query( + CHECK_USER_HAS_COMMUNITIES, + end_user_id=end_user_id, + ) + return result[0]["community_count"] > 0 if result else False + except Exception as e: + logger.error(f"has_communities failed: {e}") + return False + + async def refresh_member_count( + self, community_id: str, end_user_id: str + ) -> int: + """重新统计并更新社区成员数,返回最新数量。""" + try: + result = await self.connector.execute_query( + UPDATE_COMMUNITY_MEMBER_COUNT, + community_id=community_id, + end_user_id=end_user_id, + ) + return result[0]["member_count"] if result else 0 + except Exception as e: + logger.error(f"refresh_member_count failed: {e}") + return 0 + + async def update_community_metadata( + self, + community_id: str, + end_user_id: str, + name: str, + summary: str, + core_entities: List[str], + ) -> bool: + """更新社区的名称、摘要和核心实体列表。""" + try: + result = await self.connector.execute_query( + UPDATE_COMMUNITY_METADATA, + community_id=community_id, + end_user_id=end_user_id, + name=name, + summary=summary, + core_entities=core_entities, + ) + return bool(result) + except Exception as e: + logger.error(f"update_community_metadata failed: {e}") + return False diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 651c513f..48a5ac87 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1058,4 +1058,147 @@ Graph_Node_query = """ 3 AS priority LIMIT $limit - """ \ No newline at end of file + """ + + +# ============================================================ +# Community 节点 & BELONGS_TO_COMMUNITY 边 +# ============================================================ + +# ─── Community 聚类相关 Cypher 模板 ─────────────────────────────────────────── + +COMMUNITY_NODE_UPSERT = """ +MERGE (c:Community {community_id: $community_id}) +SET c.end_user_id = $end_user_id, + c.member_count = $member_count, + c.updated_at = datetime() +RETURN c.community_id AS community_id +""" + +ENTITY_JOIN_COMMUNITY = """ +MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +MERGE (e)-[:BELONGS_TO_COMMUNITY]->(c) +SET c.updated_at = datetime() +RETURN e.id AS entity_id, c.community_id AS community_id +""" + +ENTITY_LEAVE_ALL_COMMUNITIES = """ +MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) +MATCH (e)-[r:BELONGS_TO_COMMUNITY]->(:Community) +DELETE r +""" + +GET_ENTITY_NEIGHBORS = """ +MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) + +// 来源一:直接关系邻居(EXTRACTED_RELATIONSHIP 边) +OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id}) + +// 来源二:同 Statement 共现邻居(REFERENCES_ENTITY 边) +OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id}) +WHERE nb2.id <> e.id + +WITH collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors +UNWIND all_neighbors AS nb +WITH nb WHERE nb IS NOT NULL +OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN DISTINCT + nb.id AS id, + nb.name AS name, + nb.name_embedding AS name_embedding, + nb.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +""" + +GET_ALL_ENTITIES_FOR_USER = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN e.id AS id, + e.name AS name, + e.name_embedding AS name_embedding, + e.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +""" + +GET_COMMUNITY_MEMBERS = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) +RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, + e.importance_score AS importance_score, e.activation_value AS activation_value, + e.name_embedding AS name_embedding +ORDER BY coalesce(e.activation_value, 0) DESC +""" + +GET_ALL_COMMUNITY_MEMBERS_BATCH = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community) +WHERE c.community_id IN $community_ids +RETURN c.community_id AS community_id, + e.id AS id, + e.name_embedding AS name_embedding, + e.activation_value AS activation_value +""" + +CHECK_USER_HAS_COMMUNITIES = """ +MATCH (c:Community {end_user_id: $end_user_id}) +RETURN count(c) AS community_count +""" + +UPDATE_COMMUNITY_MEMBER_COUNT = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) +WITH c, count(e) AS cnt +SET c.member_count = cnt +RETURN c.community_id AS community_id, cnt AS member_count +""" + +UPDATE_COMMUNITY_METADATA = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +SET c.name = $name, + c.summary = $summary, + c.core_entities = $core_entities, + c.updated_at = datetime() +RETURN c.community_id AS community_id +""" + +GET_ALL_ENTITY_NEIGHBORS_BATCH = """ +// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载) +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) + +// 来源一:直接关系邻居 +OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id}) + +// 来源二:同 Statement 共现邻居 +OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id}) +WHERE nb2.id <> e.id + +WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors +UNWIND all_neighbors AS nb +WITH e, nb WHERE nb IS NOT NULL +OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN DISTINCT + e.id AS entity_id, + nb.id AS id, + nb.name AS name, + nb.name_embedding AS name_embedding, + nb.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +""" + +GET_COMMUNITY_GRAPH_DATA = """ +MATCH (c:Community {end_user_id: $end_user_id}) +MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[b:BELONGS_TO_COMMUNITY]->(c) +OPTIONAL MATCH (e)-[r:EXTRACTED_RELATIONSHIP]-(e2:ExtractedEntity {end_user_id: $end_user_id}) +RETURN + elementId(c) AS c_id, + properties(c) AS c_props, + elementId(e) AS e_id, + properties(e) AS e_props, + elementId(b) AS b_id, + elementId(e2) AS e2_id, + properties(e2) AS e2_props, + elementId(r) AS r_id, + type(r) AS r_type, + properties(r) AS r_props, + startNode(r) = e AS r_from_e +""" diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 526d16ec..cbd2b532 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -1,4 +1,6 @@ -from typing import List +import asyncio +import os +from typing import List, Optional # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -155,7 +157,9 @@ async def save_dialog_and_statements_to_neo4j( entity_edges: List[EntityEntityEdge], statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + connector: Neo4jConnector, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. @@ -288,6 +292,10 @@ async def save_dialog_and_statements_to_neo4j( } logger.info("Transaction completed. Summary: %s", summary) logger.debug("Full transaction results: %r", results) + + # 写入成功后,异步触发聚类(不阻塞写入响应) + schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id) + return True except Exception as e: @@ -295,3 +303,55 @@ async def save_dialog_and_statements_to_neo4j( print(f"Neo4j integration error: {e}") print("Continuing without database storage...") return False + + +def schedule_clustering_after_write( + entity_nodes: List, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, +) -> None: + """ + 写入 Neo4j 成功后,调度后台聚类任务。 + + 可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。 + 使用 asyncio.create_task 异步触发,不阻塞写入响应。 + """ + if not entity_nodes: + return + + clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false" + if not clustering_enabled: + logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发") + return + + end_user_id = entity_nodes[0].end_user_id + new_entity_ids = [e.id for e in entity_nodes] + logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") + asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id)) + + +async def _trigger_clustering( + new_entity_ids: List[str], + end_user_id: str, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, +) -> None: + """ + 聚类触发函数,自动判断全量初始化还是增量更新。 + """ + connector = None + try: + from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine + logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") + connector = Neo4jConnector() + engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id) + await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) + logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}") + except Exception as e: + logger.error(f"[Clustering] 聚类触发失败: {e}", exc_info=True) + finally: + if connector: + try: + await connector.close() + except Exception: + pass diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index 26a7390b..b4efe61d 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -26,5 +26,7 @@ class AgentMemory_Long_Term(ABC): STRATEGY_TIME = "time" DEFAULT_SCOPE = 6 TIME_SCOPE=5 - +class AgentMemoryDataset(ABC): + PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余'] + NAME='用户' diff --git a/api/app/services/app_dsl_service.py b/api/app/services/app_dsl_service.py index d12d1009..a10aa70a 100644 --- a/api/app/services/app_dsl_service.py +++ b/api/app/services/app_dsl_service.py @@ -19,6 +19,7 @@ from app.models.tool_model import ToolConfig as ToolConfigModel from app.models.workflow_model import WorkflowConfig from app.services.workflow_service import WorkflowService from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter +from app.models.memory_config_model import MemoryConfig as MemoryConfigModel class AppDslService: @@ -423,9 +424,19 @@ class AppDslService: config_id = memory.get("memory_config_id") or memory.get("memory_content") if not config_id: return memory - from app.models.memory_config_model import MemoryConfig as MemoryConfigModel + try: + config_uuid = uuid.UUID(str(config_id)) + except (ValueError, AttributeError): + exists = self.db.query(MemoryConfigModel).filter( + MemoryConfigModel.config_id_old == int(config_id), + MemoryConfigModel.workspace_id == workspace_id + ).first() + if not exists: + warnings.append(f"记忆配置 '{config_id}' 未匹配,已置空,请导入后手动配置") + return {**memory, "memory_config_id": None, "enabled": False} + return memory exists = self.db.query(MemoryConfigModel).filter( - MemoryConfigModel.config_id == config_id, + MemoryConfigModel.config_id == config_uuid, MemoryConfigModel.workspace_id == workspace_id ).first() if not exists: diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index b3b136a1..619a5f10 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -98,7 +98,7 @@ def create_long_term_memory_tool( **重要:如果用户的问题可以直接回答,不要调用此工具。只在确实需要历史信息时才使用。** Args: - question: 需要检索的问题(保持原问题的核心语义,使用清晰的关键词) + question: 需要检索的问题(保持原问题的核心语义,使用清晰的关键词,第三人称描述的偏好、行为通常指用户本人,比如(我,本人,在下,自己,咱,鄙人,吴,余)通指用户) Returns: 检索到的历史记忆内容 diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index db49c50a..be656acb 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -535,7 +535,8 @@ def get_users_total_chunk_batch( def get_rag_content( end_user_id: str, - limit: int, + page: int, + pagesize: int, db: Session, current_user: User ) -> dict: @@ -543,9 +544,9 @@ def get_rag_content( 先在documents表中查询file_name=='end_user_id'+'.txt'的id和kb_id, 然后调用/chunks/{kb_id}/{document_id}/chunks接口的相关代码获取所有内容, 接着对获取的内容进行提取,只要page_content的内容, - 最后返回数据 + 最后返回分页数据 """ - business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}") + business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}") try: from app.models.document_model import Document @@ -562,63 +563,76 @@ def get_rag_content( if not documents: business_logger.warning(f"未找到文件: {file_name}") return { - "total": 0, - "contents": [] + "page": { + "page": page, + "pagesize": pagesize, + "total": 0, + "hasnext": False, + }, + "items": [] } business_logger.info(f"找到 {len(documents)} 个文档记录") - # 3. 获取所有chunks的page_content - all_contents = [] - total_chunks = 0 + # 3. 按全局偏移量计算当前页数据 + # 全局偏移范围:[offset_start, offset_end) + offset_start = (page - 1) * pagesize + offset_end = offset_start + pagesize + + global_total = 0 # 所有文档的 chunk 总数 + page_contents = [] # 当前页的内容 for document in documents: try: - # 获取知识库信息 kb = knowledge_repository.get_knowledge_by_id(db, document.kb_id) if not kb: business_logger.warning(f"知识库不存在: kb_id={document.kb_id}") continue - # 初始化向量服务 vector_service = ElasticSearchVectorFactory().init_vector(knowledge=kb) - # 获取该文档的所有chunks(分页获取) - page = 1 - pagesize = 100 # 每页100条 + # 先用 pagesize=1 获取该文档的 chunk 总数 + doc_total, _ = vector_service.search_by_segment( + document_id=str(document.id), + query=None, + pagesize=1, + page=1, + asc=True + ) - while True: - total, items = vector_service.search_by_segment( + doc_offset_start = global_total # 该文档在全局中的起始偏移 + doc_offset_end = global_total + doc_total # 该文档在全局中的结束偏移 + global_total += doc_total + + # 当前页与该文档无交集,跳过 + if doc_offset_end <= offset_start or doc_offset_start >= offset_end: + continue + + # 计算需要从该文档取的局部范围 + local_start = max(offset_start - doc_offset_start, 0) + local_end = min(offset_end - doc_offset_start, doc_total) + need_count = local_end - local_start + + # 换算成 ES 分页参数(ES page 从1开始) + es_page = (local_start // pagesize) + 1 + es_offset_in_page = local_start % pagesize + + fetched = [] + while len(fetched) < es_offset_in_page + need_count: + _, items = vector_service.search_by_segment( document_id=str(document.id), query=None, pagesize=pagesize, - page=page, + page=es_page, asc=True ) - if not items: break - - # 提取page_content - for item in items: - all_contents.append(item.page_content) - total_chunks += 1 - - # # 如果达到limit限制,直接返回 - # if limit > 0 and total_chunks >= limit: - # business_logger.info(f"已达到limit限制: {limit}") - # return { - # "total": total_chunks, - # "contents": all_contents[:limit] - # } - - # 检查是否还有下一页 - if page * pagesize >= total: - break - - page += 1 + fetched.extend(items) + es_page += 1 - business_logger.info(f"文档 {document.id} 获取了 {len(items)} 个chunks") + slice_items = fetched[es_offset_in_page: es_offset_in_page + need_count] + page_contents.extend([item.page_content for item in slice_items]) except Exception as e: business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}") @@ -626,11 +640,16 @@ def get_rag_content( # 4. 返回结果 result = { - "total": total_chunks, - "contents": all_contents[:limit] if limit > 0 else all_contents + "page": { + "page": page, + "pagesize": pagesize, + "total": global_total, + "hasnext": offset_end < global_total, + }, + "items": page_contents } - business_logger.info(f"成功获取RAG内容: total={total_chunks}, 返回={len(result['contents'])} 条") + business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)} 条") return result except Exception as e: @@ -730,8 +749,8 @@ async def generate_rag_profile( if not end_user: raise ValueError(f"end_user {end_user_id} 不存在") - rag_content = get_rag_content(end_user_id, limit, db, current_user) - chunks = rag_content.get("contents", []) + rag_content = get_rag_content(end_user_id, page=1, pagesize=limit, db=db, current_user=current_user) + chunks = rag_content.get("items", []) if not chunks: business_logger.warning(f"未找到chunk内容,无法生产RAG画像: end_user_id={end_user_id}") diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 8bacc112..d5d19e0d 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -1727,6 +1727,150 @@ async def analytics_graph_data( # 辅助函数 +async def analytics_community_graph_data( + db: Session, + end_user_id: str, +) -> Dict[str, Any]: + """ + 获取社区图谱数据,包含 Community 节点、ExtractedEntity 节点及其关系。 + + Returns: + 包含 nodes、edges、statistics 的字典,格式与 analytics_graph_data 一致 + """ + try: + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + if not end_user: + return { + "nodes": [], "edges": [], + "statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}}, + "message": "用户不存在" + } + + # 查询社区节点、实体节点、BELONGS_TO_COMMUNITY 边、实体间关系 + from app.repositories.neo4j.cypher_queries import GET_COMMUNITY_GRAPH_DATA + rows = await _neo4j_connector.execute_query(GET_COMMUNITY_GRAPH_DATA, end_user_id=end_user_id) + + nodes_map: Dict[str, dict] = {} + edges_map: Dict[str, dict] = {} + # 记录每个 Community 对应的实体 id 列表 + community_members: Dict[str, list] = {} + + for row in rows: + # Community 节点 + c_id = row["c_id"] + if c_id and c_id not in nodes_map: + raw = row["c_props"] or {} + props = {k: _clean_neo4j_value(raw.get(k)) for k in ( + "community_id", "end_user_id", "member_count", "updated_at", + "name", "summary", "core_entities", + ) if k in raw} + nodes_map[c_id] = { + "id": c_id, + "label": "Community", + "properties": props, + } + + # ExtractedEntity 节点 (e) + e_id = row["e_id"] + if e_id and e_id not in nodes_map: + raw = row["e_props"] or {} + props = {k: _clean_neo4j_value(raw.get(k)) for k in ( + "name", "end_user_id", "description", "created_at", "entity_type", + ) if k in raw} + # 注入所属社区名称(c 是 e 直接归属的社区) + c_raw = row["c_props"] or {} + props["community_name"] = _clean_neo4j_value(c_raw.get("name")) or "" + nodes_map[e_id] = { + "id": e_id, + "label": "ExtractedEntity", + "properties": props, + } + + # ExtractedEntity 节点 (e2,可选) + e2_id = row.get("e2_id") + if e2_id and e2_id not in nodes_map: + raw = row["e2_props"] or {} + props = {k: _clean_neo4j_value(raw.get(k)) for k in ( + "name", "end_user_id", "description", "created_at", "entity_type", + ) if k in raw} + # e2 的社区归属在后处理阶段通过 community_members 补充 + props["community_name"] = "" + nodes_map[e2_id] = { + "id": e2_id, + "label": "ExtractedEntity", + "properties": props, + } + + # BELONGS_TO_COMMUNITY 边 + b_id = row["b_id"] + if b_id and b_id not in edges_map: + edges_map[b_id] = { + "id": b_id, + "source": e_id, + "target": c_id, + } + # 收集社区成员 id + if c_id and e_id: + community_members.setdefault(c_id, []) + if e_id not in community_members[c_id]: + community_members[c_id].append(e_id) + + # EXTRACTED_RELATIONSHIP 边(可选) + r_id = row.get("r_id") + if r_id and r_id not in edges_map and e2_id: + r_props = {k: _clean_neo4j_value(v) for k, v in (row["r_props"] or {}).items()} + source = e_id if row.get("r_from_e") else e2_id + target = e2_id if row.get("r_from_e") else e_id + edges_map[r_id] = { + "id": r_id, + "source": source, + "target": target, + } + + nodes = list(nodes_map.values()) + edges = list(edges_map.values()) + + # 为每个 Community 节点注入 member_entity_ids,同时补全 e2 节点的 community_name + for c_id, member_ids in community_members.items(): + c_node = nodes_map.get(c_id) + if c_node: + c_node["properties"]["member_entity_ids"] = member_ids + c_name = c_node["properties"].get("name") or "" + # 补全属于该社区但 community_name 为空的实体(即 e2 节点) + for eid in member_ids: + e_node = nodes_map.get(eid) + if e_node and e_node["label"] == "ExtractedEntity": + if not e_node["properties"].get("community_name"): + e_node["properties"]["community_name"] = c_name + + node_type_counts: Dict[str, int] = {} + for n in nodes: + node_type_counts[n["label"]] = node_type_counts.get(n["label"], 0) + 1 + + return { + "nodes": nodes, + "edges": edges, + "statistics": { + "total_nodes": len(nodes), + "total_edges": len(edges), + "node_types": node_type_counts, + } + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "nodes": [], "edges": [], + "statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}}, + "message": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"获取社区图谱数据失败: {str(e)}", exc_info=True) + raise + + async def _extract_node_properties(label: str, properties: Dict[str, Any],node_id: str) -> Dict[str, Any]: """ 根据节点类型提取需要的属性字段 diff --git a/api/app/tasks.py b/api/app/tasks.py index 5e1550bd..cae3719b 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1158,13 +1158,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s try: _r = get_sync_redis_client() if _r is not None: - from datetime import timedelta as _td from datetime import timezone as _tz - _CST = _tz(_td(hours=8)) - _now_cst = datetime.now(_CST).replace(tzinfo=None).isoformat() + _now_utc = datetime.now(_tz.utc).isoformat() _r.set( f"write_message:last_done:{end_user_id}", - _now_cst, + _now_utc, ex=86400 * 30, ) except Exception as _e: @@ -2662,3 +2660,134 @@ def write_perceptual_memory( file_url, file_message, )) + + +# ============================================================================= +# 社区聚类补全任务(触发型) +# ============================================================================= + +@celery_app.task( + name="app.tasks.init_community_clustering_for_users", + bind=True, + ignore_result=False, + max_retries=0, + acks_late=False, + time_limit=7200, # 2小时硬超时 + soft_time_limit=6900, +) +def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: + """触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。 + + 由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。 + + Args: + end_user_ids: 需要检查的用户 ID 列表 + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.logging_config import get_logger + from app.repositories.neo4j.community_repository import CommunityRepository + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine + + logger = get_logger(__name__) + logger.info(f"[CommunityCluster] 开始社区聚类补全任务,候选用户数: {len(end_user_ids)}") + + initialized = 0 + skipped = 0 + failed = 0 + + connector = Neo4jConnector() + try: + repo = CommunityRepository(connector) + + # 批量预取所有用户的配置(内置兜底:用户配置不可用时自动回退到工作空间默认配置) + user_llm_map: Dict[str, Optional[str]] = {} + try: + with get_db_context() as db: + from app.services.memory_agent_service import get_end_users_connected_configs_batch + from app.services.memory_config_service import MemoryConfigService + batch_configs = get_end_users_connected_configs_batch(end_user_ids, db) + for uid, cfg_info in batch_configs.items(): + config_id = cfg_info.get("memory_config_id") + if config_id: + try: + cfg = MemoryConfigService(db).load_memory_config(config_id=config_id) + user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None + except Exception as e: + logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}") + user_llm_map[uid] = None + else: + user_llm_map[uid] = None + except Exception as e: + logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}") + + for end_user_id in end_user_ids: + try: + # 已有社区节点则跳过 + has_communities = await repo.has_communities(end_user_id) + if has_communities: + skipped += 1 + logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过") + continue + + # 检查是否有 ExtractedEntity 节点 + entities = await repo.get_all_entities(end_user_id) + if not entities: + skipped += 1 + logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过") + continue + + # 每个用户使用自己的 llm_model_id + llm_model_id = user_llm_map.get(end_user_id) + engine = LabelPropagationEngine( + connector=connector, + llm_model_id=llm_model_id, + ) + + logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") + await engine.full_clustering(end_user_id) + initialized += 1 + logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成") + + except Exception as e: + failed += 1 + logger.error(f"[CommunityCluster] 用户 {end_user_id} 聚类失败: {e}") + + finally: + await connector.close() + + logger.info( + f"[CommunityCluster] 任务完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}" + ) + return { + "status": "SUCCESS", + "initialized": initialized, + "skipped": skipped, + "failed": failed, + } + + try: + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + loop = set_asyncio_event_loop() + result = loop.run_until_complete(_run()) + result["elapsed_time"] = time.time() - start_time + result["task_id"] = self.request.id + return result + + except Exception as e: + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } diff --git a/api/app/version_info.json b/api/app/version_info.json index bbaffc17..12793cb5 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,38 @@ { + "v0.2.7": { + "introduction": { + "codeName": "武陵", + "releaseDate": "2026-3-13", + "upgradePosition": "🐻 应用可移植性、工具生态扩展与记忆智能精细化", + "coreUpgrades": [ + "1. 应用管理与可移植性
* 应用导入/导出:全面支持 Agent 配置和工作流定义的导入导出,实现跨环境无缝迁移、备份和共享", + "2. 工具生态扩展 🔌
* MCP 广场集成:工具管理接入 MCP 广场,提供集中式工具发现、浏览和集成枢纽", + "3. 工作流增强 📝
* 备注节点:新增备注节点类型,支持工作流图中的内联文档和上下文说明,提升协作效率", + "4. 记忆智能精细化 🧠
* 隐性记忆与情绪记忆生成逻辑优化:含数据存在性校验、时间轴筛选和兴趣分布缓存校验
* 兴趣分布生成逻辑改进:优化算法产生更准确的用户兴趣画像", + "5. 用户体验改进 🎨
* 知识库分享加载状态:增加加载指示器,改善感知响应速度", + "6. 稳健性与缺陷修复 🔧
* 应用调试终端用户管理:修复调试会话错误创建 end_user 记录问题
* 知识库数据集创建流程:解决创建数据集后无法进入下一步的缺陷
* RAG 空间记忆生成失败:修复记忆生成失败和存储中断的关键问题
* 应用字符限制强制执行:增加条件校验防止过长输入
* 语义剪枝情绪/兴趣保留:优化剪枝逻辑防止误删情绪和兴趣片段
* 语义剪枝效果优化:增强算法平衡记忆压缩与信息保留", + "
", + "v0.2.8 及更远的未来将引入多模态记忆能力,实现知识库和模型的分服务部署,为应用增加语音输入支持,并扩展应用能力至语音回复、BI 可视化、PPT 生成和直接生图。应用会话分享和联网搜索功能将得到修复和增强。记忆检索基准测试和情景记忆聚类算法将增强上下文召回和时序推理能力。通往真正智能、多模态、上下文感知应用的旅程仍在继续。", + "记忆熊,智慧致远 🐻✨" + ] + }, + "introduction_en": { + "codeName": "WuLing", + "releaseDate": "2026-3-13", + "upgradePosition": "🐻 Application portability, tool ecosystem expansion, and memory intelligence refinement", + "coreUpgrades": [ + "1. Application Management & Portability
* Application Import/Export: Full support for importing and exporting agent configurations and workflow definitions, enabling seamless cross-environment migration, backup, and sharing", + "2. Tool Ecosystem Expansion 🔌
* MCP Marketplace Integration: Tool management now includes MCP Marketplace access for centralized tool discovery, browsing, and integration", + "3. Workflow Enhancements 📝
* Annotation Node: Introduced annotation node type for inline documentation and contextual notes within workflow graphs, improving collaboration", + "4. Memory Intelligence Refinement 🧠
* Implicit & Emotional Memory Generation Logic: Comprehensive optimization including data existence validation, timeline filtering, and interest distribution cache validation
* Interest Distribution Generation Logic: Refined algorithm for more accurate user interest profiles", + "5. User Experience Improvements 🎨
* Knowledge Base Sharing Loading State: Added loading indicators to improve perceived responsiveness", + "6. Robustness & Bug Fixes 🔧
* End User Management in App Debugging: Fixed incorrect end_user record creation during debugging sessions
* Knowledge Base Dataset Creation Flow: Resolved bug preventing next step after dataset creation
* RAG Space Memory Generation Failure: Fixed critical memory generation and storage interruption issue
* Application Character Limit Enforcement: Added conditional validation to prevent excessively long input
* Semantic Pruning Emotion/Interest Preservation: Optimized pruning logic to prevent incorrect deletion of emotional and interest fragments
* Semantic Pruning Effectiveness: Enhanced algorithm balance between memory compression and information retention", + "
", + "Looking forward to v0.2.8 and beyond, we will introduce multimodal memory capabilities with distributed service deployment for knowledge bases and models, enabling voice input for applications and expanding application capabilities with voice responses, BI visualizations, PPT generation, and direct image creation. Application conversation sharing and web search functionality will be restored and enhanced. Memory retrieval benchmarking and episodic memory clustering algorithms will enhance contextual recall and temporal reasoning. The journey toward truly intelligent, multimodal, context-aware applications continues.", + "MemoryBear, Wisdom Reaching Far 🐻✨" + ] + } + }, "v0.2.6": { "introduction": { "codeName": "听剑", diff --git a/api/migrations/versions/ef9d172cb753_202603131800.py b/api/migrations/versions/ef9d172cb753_202603131800.py new file mode 100644 index 00000000..efeaee1c --- /dev/null +++ b/api/migrations/versions/ef9d172cb753_202603131800.py @@ -0,0 +1,30 @@ +"""202603131800 + +Revision ID: ef9d172cb753 +Revises: ea31b4e347d8 +Create Date: 2026-03-13 18:01:11.167711 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'ef9d172cb753' +down_revision: Union[str, None] = 'ea31b4e347d8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('app_shares', sa.Column('is_active', sa.Boolean(), server_default='true', nullable=False, comment='是否有效,False 表示逻辑删除')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('app_shares', 'is_active') + # ### end Alembic commands ### diff --git a/web/package.json b/web/package.json index 2799a631..b9e3709e 100644 --- a/web/package.json +++ b/web/package.json @@ -44,6 +44,7 @@ "i18next": "^25.6.0", "js-yaml": "^4.1.1", "lexical": "^0.39.0", + "mammoth": "^1.12.0", "mermaid": "^11.12.1", "react": "^18.2.0", "react-dom": "^18.2.0", @@ -59,6 +60,7 @@ "remark-gfm": "^4.0.1", "remark-math": "^6.0.0", "tailwindcss": "^4.1.14", + "xlsx": "^0.18.5", "zustand": "^5.0.8" }, "devDependencies": { diff --git a/web/src/api/application.ts b/web/src/api/application.ts index 71048454..6035afe2 100644 --- a/web/src/api/application.ts +++ b/web/src/api/application.ts @@ -2,11 +2,11 @@ * @Author: ZhaoYing * @Date: 2026-02-03 13:59:45 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-03 12:08:42 + * @Last Modified time: 2026-03-13 17:07:54 */ import { request } from '@/utils/request' import type { ApplicationModalData } from '@/views/ApplicationManagement/types' -import type { Config } from '@/views/ApplicationConfig/types' +import type { Config, AppSharingForm } from '@/views/ApplicationConfig/types' import { handleSSE, type SSEMessage } from '@/utils/stream' import type { QueryParams } from '@/views/Conversation/types' import type { WorkflowConfig } from '@/views/Workflow/types' @@ -113,8 +113,8 @@ export const getShareToken = (share_token: string, user_id: string) => { return request.post(`/public/share/${share_token}/token`, { user_id }) } // Copy application -export const copyApplication = (app_id: string, new_name: string) => { - return request.post(`/apps/${app_id}/copy?new_name=${new_name}`) +export const copyApplication = (app_id: string, new_name?: string) => { + return request.post(`/apps/${app_id}/copy`, { new_name }) } // Data statistics export const getAppStatistics = (app_id: string, data: { start_date: number; end_date: number; }) => { @@ -143,4 +143,26 @@ export const appExport = (app_id: string, appName: string, data?: { release_vers // Import application export const appImport = (formData: FormData) => { return request.uploadFile(`/apps/import`, formData) -} \ No newline at end of file +} + +// Share application +export const appSharing = (app_id: string, data: AppSharingForm) => { + return request.post(`/apps/${app_id}/share`, data) +} +// Get my shared application records +export const mySharedOutList = () => { + return request.get(`/apps/my-shared-out`) +} +// Get sharing records for a specific application +export const getAppShares = (app_id: string) => { + return request.get(`/apps/${app_id}/shares`) +} +// Cancel a single share (source side operation) +export const cancelShare = (app_id: string, target_workspace_id?: string) => { + return request.delete(`/apps/${app_id}/share/${target_workspace_id}`) +} +// Cancel all shares under a workspace (source side operation) +export const cancelSpaceShare = (target_workspace_id?: string) => { + return request.delete(`/apps/share/${target_workspace_id}`) +} + diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index b8bfac32..823e3d78 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -123,8 +123,9 @@ export const getChunkInsight = (end_user_id: string) => { return request.get(`/dashboard/chunk_insight`, { end_user_id }) } // RAG User Memory - Storage content -export const getRagContent = (end_user_id: string) => { - return request.get(`/dashboard/rag_content`, { end_user_id, limit: 20 }) +export const getRagContentUrl = '/dashboard/rag_content' +export const getRagContent = (end_user_id: string, page = 1, pagesize = 20) => { + return request.get(getRagContentUrl, { end_user_id, page, pagesize }) } // Emotion distribution analysis export const getWordCloud = (end_user_id: string) => { diff --git a/web/src/api/tools.ts b/web/src/api/tools.ts index 7c7a0e3d..612f924d 100644 --- a/web/src/api/tools.ts +++ b/web/src/api/tools.ts @@ -6,12 +6,12 @@ export const getTools = (data: Query) => { return request.get('/tools', data) } // 创建MCP工具 -export const addTool = (values: MCPToolItem | CustomToolItem) => { - return request.post('/tools', values) +export const addTool = (values: MCPToolItem | CustomToolItem, config?: { signal?: AbortSignal }) => { + return request.post('/tools', values, config) } // 更新工具 -export const updateTool = (tool_id: string, data: MCPToolItem | InnerToolItem | CustomToolItem) => { - return request.put(`/tools/${tool_id}`, data) +export const updateTool = (tool_id: string, data: MCPToolItem | InnerToolItem | CustomToolItem, config?: { signal?: AbortSignal }) => { + return request.put(`/tools/${tool_id}`, data, config) } // 删除工具 export const deleteTool = (tool_id: string) => { diff --git a/web/src/api/workspaces.ts b/web/src/api/workspaces.ts index 01f3be72..5c62489d 100644 --- a/web/src/api/workspaces.ts +++ b/web/src/api/workspaces.ts @@ -1,16 +1,16 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 14:00:26 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 14:00:26 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-13 15:29:03 */ import { request } from '@/utils/request' import type { SpaceModalData } from '@/views/SpaceManagement/types' import type { SpaceConfigData } from '@/views/SpaceConfig/types' // Workspace list -export const getWorkspaces = () => { - return request.get('/workspaces') +export const getWorkspaces = (data?: { include_current?: boolean }) => { + return request.get('/workspaces', data) } // Create workspace export const createWorkspace = (values: SpaceModalData) => { diff --git a/web/src/assets/images/file/audio.svg b/web/src/assets/images/file/audio.svg new file mode 100644 index 00000000..0826c7f8 --- /dev/null +++ b/web/src/assets/images/file/audio.svg @@ -0,0 +1,11 @@ + + + 音乐 + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/csv.svg b/web/src/assets/images/file/csv.svg new file mode 100644 index 00000000..1b8fc721 --- /dev/null +++ b/web/src/assets/images/file/csv.svg @@ -0,0 +1,16 @@ + + + 编组 57 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/excel.svg b/web/src/assets/images/file/excel.svg new file mode 100644 index 00000000..cd09cc8c --- /dev/null +++ b/web/src/assets/images/file/excel.svg @@ -0,0 +1,15 @@ + + + Excel + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/html.svg b/web/src/assets/images/file/html.svg new file mode 100644 index 00000000..641f97a2 --- /dev/null +++ b/web/src/assets/images/file/html.svg @@ -0,0 +1,15 @@ + + + Word + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/image.svg b/web/src/assets/images/file/image.svg new file mode 100644 index 00000000..f81baa50 --- /dev/null +++ b/web/src/assets/images/file/image.svg @@ -0,0 +1,15 @@ + + + 编组 58 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/json.svg b/web/src/assets/images/file/json.svg new file mode 100644 index 00000000..4ced0745 --- /dev/null +++ b/web/src/assets/images/file/json.svg @@ -0,0 +1,12 @@ + + + JSON + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/md.svg b/web/src/assets/images/file/md.svg new file mode 100644 index 00000000..c2cb9619 --- /dev/null +++ b/web/src/assets/images/file/md.svg @@ -0,0 +1,17 @@ + + + PDF + + + + + + + + + MD + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/pdf.svg b/web/src/assets/images/file/pdf.svg new file mode 100644 index 00000000..10c3020b --- /dev/null +++ b/web/src/assets/images/file/pdf.svg @@ -0,0 +1,18 @@ + + + PDF + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/ppt.svg b/web/src/assets/images/file/ppt.svg new file mode 100644 index 00000000..eb3d4d8d --- /dev/null +++ b/web/src/assets/images/file/ppt.svg @@ -0,0 +1,12 @@ + + + file-ppt-2-fill + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/txt.svg b/web/src/assets/images/file/txt.svg new file mode 100644 index 00000000..141d2bfb --- /dev/null +++ b/web/src/assets/images/file/txt.svg @@ -0,0 +1,12 @@ + + + txt + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/video.svg b/web/src/assets/images/file/video.svg new file mode 100644 index 00000000..08c0b262 --- /dev/null +++ b/web/src/assets/images/file/video.svg @@ -0,0 +1,14 @@ + + + 编组 59 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/word.svg b/web/src/assets/images/file/word.svg new file mode 100644 index 00000000..dc37637d --- /dev/null +++ b/web/src/assets/images/file/word.svg @@ -0,0 +1,15 @@ + + + Word + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/AudioRecorder/index.tsx b/web/src/components/AudioRecorder/index.tsx index d31746f6..10b8eca9 100644 --- a/web/src/components/AudioRecorder/index.tsx +++ b/web/src/components/AudioRecorder/index.tsx @@ -1,13 +1,23 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-06 21:11:51 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-13 17:11:14 + */ import { type FC, useRef, useState } from 'react' import RecordRTC from 'recordrtc' import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' import { request } from '@/utils/request' +/** Props for the AudioRecorder component */ interface AudioRecorderProps { + /** Callback fired when recording is complete, receives uploaded file info and raw blob */ onRecordingComplete?: (file: { file_id: string; file_key: string; url: string; type?: string; }, blob?: Blob) => void className?: string; + /** Upload endpoint URL, defaults to fileUploadUrlWithoutApiPrefix */ action?: string; + /** Additional config passed to the upload request */ requestConfig?: Record; } @@ -17,9 +27,12 @@ const AudioRecorder: FC = ({ action = fileUploadUrlWithoutApiPrefix, requestConfig = {} }) => { + // Whether the recorder is currently capturing audio const [isRecording, setIsRecording] = useState(false) + // Holds the RecordRTC instance across renders const recorderRef = useRef(null) + /** Request microphone access and start recording */ const startRecording = async () => { try { const stream = await navigator.mediaDevices.getUserMedia({ audio: true }) @@ -34,6 +47,7 @@ const AudioRecorder: FC = ({ } } + /** Stop recording, upload the audio blob, then invoke the completion callback */ const stopRecording = () => { if (recorderRef.current) { recorderRef.current.stopRecording(() => { @@ -49,6 +63,7 @@ const AudioRecorder: FC = ({ type: blob.type, url }, blob) + // Release recorder resources after upload recorderRef.current?.destroy() recorderRef.current = null }) @@ -57,12 +72,14 @@ const AudioRecorder: FC = ({ } } + // Toggle between recording/idle states on click; + // swap background image to reflect current state return (
diff --git a/web/src/components/ButtonCheckbox/index.tsx b/web/src/components/ButtonCheckbox/index.tsx index 4b43f18a..81396648 100644 --- a/web/src/components/ButtonCheckbox/index.tsx +++ b/web/src/components/ButtonCheckbox/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-02 15:01:59 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-02 15:46:05 + * @Last Modified time: 2026-03-12 14:59:38 */ /** @@ -15,7 +15,7 @@ */ import { type FC, type ReactNode, useEffect } from 'react'; -import { type RadioGroupProps } from 'antd'; +import { type RadioGroupProps, Flex } from 'antd'; import clsx from 'clsx' // Button checkbox component props @@ -32,6 +32,7 @@ interface ButtonCheckboxProps extends Omit { checkedIcon?: string; /** Button content */ children?: ReactNode + cicle?: boolean; } const ButtonCheckbox: FC = ({ @@ -41,6 +42,7 @@ const ButtonCheckbox: FC = ({ icon, checkedIcon, children, + cicle = false }) => { // Listen to value changes and trigger side effects via onValueChange callback useEffect(() => { @@ -57,21 +59,26 @@ const ButtonCheckbox: FC = ({ } return ( -
{/* Display unchecked icon when not checked */} - {icon && !checked && } + {icon && !checked && } {/* Display checked icon when checked */} {checkedIcon && checked && } {children} -
+ ); }; diff --git a/web/src/components/Chat/index.tsx b/web/src/components/Chat/index.tsx index 9a60918a..9a49b0f7 100644 --- a/web/src/components/Chat/index.tsx +++ b/web/src/components/Chat/index.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2025-12-10 16:46:09 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-06 21:05:09 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-12 13:57:49 */ import { type FC } from 'react' import ChatInput from './ChatInput' @@ -25,7 +25,8 @@ const Chat: FC = ({ labelFormat, errorDesc, fileList, - fileChange + fileChange, + renderRuntime }) => { return (
@@ -37,6 +38,7 @@ const Chat: FC = ({ empty={empty} labelFormat={labelFormat} errorDesc={errorDesc} + renderRuntime={renderRuntime} /> {/* Chat input area */} diff --git a/web/src/components/Chat/types.ts b/web/src/components/Chat/types.ts index 0cf1b130..e8e00bd9 100644 --- a/web/src/components/Chat/types.ts +++ b/web/src/components/Chat/types.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2025-12-10 16:45:54 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-06 21:05:09 + * @Last Modified time: 2026-03-12 13:57:51 */ import { type ReactNode } from 'react' @@ -53,6 +53,7 @@ export interface ChatProps { fileList?: any[]; /** Attachment update */ fileChange?: (fileList: any[]) => void; + renderRuntime?: (item: ChatItem, index: number) => ReactNode; } /** diff --git a/web/src/components/DocumentPreview/index.tsx b/web/src/components/DocumentPreview/index.tsx index 404d6e50..c57080fb 100644 --- a/web/src/components/DocumentPreview/index.tsx +++ b/web/src/components/DocumentPreview/index.tsx @@ -1,20 +1,18 @@ import { useState, useEffect, type FC } from 'react'; -import { Spin, Alert, Button } from 'antd'; -import { ReloadOutlined } from '@ant-design/icons'; +import { Spin, Alert, Button, Table } from 'antd'; +import { ReloadOutlined, DownloadOutlined } from '@ant-design/icons'; import RbMarkdown from '../Markdown'; -import { cookieUtils } from '@/utils/request' - -type PreviewMode = 'office' | 'google'; +import { cookieUtils } from '@/utils/request'; +import mammoth from 'mammoth'; +import * as XLSX from 'xlsx'; interface DocumentPreviewProps { fileUrl: string; fileName?: string; - fileExt?: string; // 文件扩展名(优先使用) + fileExt?: string; width?: string | number; height?: string | number; className?: string; - mode?: PreviewMode; // 预览模式 - showModeSwitch?: boolean; // 是否显示模式切换按钮 } const DocumentPreview: FC = ({ @@ -24,18 +22,19 @@ const DocumentPreview: FC = ({ width = '100%', height = '600px', className = '', - mode = 'office', - showModeSwitch = true, }) => { const [loading, setLoading] = useState(true); const [error, setError] = useState(false); - const [currentMode, setCurrentMode] = useState(mode); + const [errorMessage, setErrorMessage] = useState(''); const [textContent, setTextContent] = useState(''); + const [htmlContent, setHtmlContent] = useState(''); + const [excelData, setExcelData] = useState<{ sheetName: string; data: any[][] }[]>([]); - // 支持的文件类型 - const supportedTypes = ['.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.pdf', '.txt', '.md', '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp']; + // 支持预览的文件类型 + const previewableTypes = ['.pdf', '.txt', '.md', '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.doc', '.docx', '.xls', '.xlsx']; + // PPT 暂不支持 + const downloadOnlyTypes = ['.ppt', '.pptx']; - // 获取文件扩展名(优先使用 fileExt prop) const getFileExtension = () => { if (fileExt) { return fileExt.toLowerCase().startsWith('.') ? fileExt.toLowerCase() : `.${fileExt.toLowerCase()}`; @@ -45,67 +44,25 @@ const DocumentPreview: FC = ({ return match ? `.${match[1].toLowerCase()}` : ''; }; - // 检查是否为文本文件 - const isTextFile = () => { - const ext = getFileExtension(); - return ext === '.txt'; - }; - - // 检查是否为 Markdown 文件 - const isMarkdownFile = () => { - const ext = getFileExtension(); - return ext === '.md'; - }; - - // 检查是否为图片文件 + const isTextFile = () => getFileExtension() === '.txt'; + const isMarkdownFile = () => getFileExtension() === '.md'; const isImageFile = () => { - const ext = getFileExtension(); const imageExts = ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp']; - return imageExts.includes(ext); - }; - - // 检查文件类型是否支持 - const isSupportedFile = () => { - const ext = getFileExtension(); - return ext && supportedTypes.includes(ext); + return imageExts.includes(getFileExtension()); }; + const isPdfFile = () => getFileExtension() === '.pdf'; + const isWordFile = () => ['.doc', '.docx'].includes(getFileExtension()); + const isExcelFile = () => ['.xls', '.xlsx'].includes(getFileExtension()); + const isPreviewable = () => previewableTypes.includes(getFileExtension()); + const isDownloadOnly = () => downloadOnlyTypes.includes(getFileExtension()); - // 检查是否为 PDF 文件 - const isPdfFile = () => { - const ext = getFileExtension(); - return ext === '.pdf'; - }; - - // 构建预览 URL - const getPreviewUrl = () => { - // 处理文件 URL,如果是完整的 URL,转换为代理路径 - let requestUrl = fileUrl; - - // 如果是完整的 https://devapi.mem.redbearai.com 开头的 URL,提取路径部分 - // 这样可以通过代理访问,避免 CORS 问题 - if (fileUrl.includes('devapi.mem.redbearai.com')) { - const url = new URL(fileUrl); - requestUrl = url.pathname; // 只取路径部分,例如 /api/files/xxx - } - - // 对于 PDF 文件,直接使用浏览器内置预览 - if (isPdfFile()) { - return requestUrl; - } - - // 确保 fileUrl 是完整的 URL(用于第三方预览服务) - let fullUrl = fileUrl; - if (!fileUrl.startsWith('http')) { - fullUrl = `${window.location.origin}${fileUrl.startsWith('/') ? '' : '/'}${fileUrl}`; - } - console.log('预览 URL:', fullUrl); - // 根据模式选择预览服务 - if (currentMode === 'google') { - return `https://docs.google.com/viewer?url=${encodeURIComponent(fullUrl)}&embedded=true`; - } - - // 默认使用 Microsoft Office Online Viewer - return `https://view.officeapps.live.com/op/embed.aspx?src=${encodeURIComponent(fullUrl)}`; + const handleDownload = () => { + const link = document.createElement('a'); + link.href = fileUrl; + link.download = fileName || 'document'; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); }; const handleLoad = () => { @@ -113,20 +70,24 @@ const DocumentPreview: FC = ({ setError(false); }; - const handleError = () => { + const handleError = (msg?: string) => { setLoading(false); setError(true); + if (msg) setErrorMessage(msg); }; const handleRetry = () => { setLoading(true); setError(false); + setErrorMessage(''); if (isTextFile() || isMarkdownFile()) { - // 重新加载文本文件 loadTextFile(); + } else if (isWordFile()) { + loadWordFile(); + } else if (isExcelFile()) { + loadExcelFile(); } else { - // 强制重新加载 iframe const iframe = document.querySelector(`iframe[title="${fileName || '文档预览'}"]`) as HTMLIFrameElement; if (iframe) { iframe.src = iframe.src; @@ -134,82 +95,164 @@ const DocumentPreview: FC = ({ } }; - const handleSwitchMode = () => { - setCurrentMode(prev => prev === 'office' ? 'google' : 'office'); - setLoading(true); - setError(false); - }; - - // 加载文本文件内容 const loadTextFile = async () => { setLoading(true); setError(false); + setErrorMessage(''); try { - // 处理文件 URL,如果是完整的 URL,转换为代理路径 let requestUrl = fileUrl; - // 如果是完整的 https://devapi.mem.redbearai.com 开头的 URL,提取路径部分 if (fileUrl.includes('devapi.mem.redbearai.com')) { const url = new URL(fileUrl); - requestUrl = url.pathname; // 只取路径部分,例如 /api/files/xxx + requestUrl = url.pathname; } const response = await fetch(requestUrl, { - credentials: 'include', // 包含认证信息 + credentials: 'include', headers: { 'Authorization': `Bearer ${cookieUtils.get('authToken') || ''}`, }, }); if (!response.ok) { - throw new Error('Failed to load file'); + throw new Error(`HTTP ${response.status}: ${response.statusText}`); } - // 检查响应的 Content-Type const contentType = response.headers.get('Content-Type') || ''; - console.log('文件 Content-Type:', contentType); - // 如果是图片类型,显示错误提示 if (contentType.startsWith('image/')) { - setError(true); - setTextContent(''); - setLoading(false); - console.error('文件实际是图片类型,但被标记为 txt'); + handleError('文件实际是图片类型,但被标记为文本文件'); return; } const text = await response.text(); - // 检查是否是二进制数据(如 PNG 文件头) if (text.startsWith('\x89PNG') || text.startsWith('�PNG')) { - setError(true); - setTextContent(''); - setLoading(false); - console.error('文件内容是 PNG 图片,但扩展名是 txt'); + handleError('文件内容是图片,但扩展名是文本'); return; } setTextContent(text); setLoading(false); - } catch (err) { + } catch (err: any) { console.error('加载文本文件失败:', err); - setError(true); - setLoading(false); + handleError(err.message || '加载文本文件失败'); + } + }; + + const loadWordFile = async () => { + setLoading(true); + setError(false); + setErrorMessage(''); + try { + let requestUrl = fileUrl; + + if (fileUrl.includes('devapi.mem.redbearai.com')) { + const url = new URL(fileUrl); + requestUrl = url.pathname; + } + + const response = await fetch(requestUrl, { + credentials: 'include', + headers: { + 'Authorization': `Bearer ${cookieUtils.get('authToken') || ''}`, + }, + }); + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + const arrayBuffer = await response.arrayBuffer(); + const result = await mammoth.convertToHtml({ arrayBuffer }); + setHtmlContent(result.value); + setLoading(false); + } catch (err: any) { + console.error('加载 Word 文件失败:', err); + handleError(err.message || '加载 Word 文件失败,文件可能已损坏'); + } + }; + + const loadExcelFile = async () => { + setLoading(true); + setError(false); + setErrorMessage(''); + try { + let requestUrl = fileUrl; + + if (fileUrl.includes('devapi.mem.redbearai.com')) { + const url = new URL(fileUrl); + requestUrl = url.pathname; + } + + const response = await fetch(requestUrl, { + credentials: 'include', + headers: { + 'Authorization': `Bearer ${cookieUtils.get('authToken') || ''}`, + }, + }); + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + const arrayBuffer = await response.arrayBuffer(); + const workbook = XLSX.read(arrayBuffer, { type: 'array' }); + + const sheets = workbook.SheetNames.map(sheetName => { + const worksheet = workbook.Sheets[sheetName]; + const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][]; + return { sheetName, data }; + }); + + setExcelData(sheets); + setLoading(false); + } catch (err: any) { + console.error('加载 Excel 文件失败:', err); + handleError(err.message || '加载 Excel 文件失败,文件可能已损坏'); } }; - // 当文件是 txt 或 md 时,加载文本内容 useEffect(() => { if (isTextFile() || isMarkdownFile()) { loadTextFile(); + } else if (isWordFile()) { + loadWordFile(); + } else if (isExcelFile()) { + loadExcelFile(); } }, [fileUrl]); - if (!isSupportedFile()) { + // PPT 文件只提供下载 + if (isDownloadOnly()) { + return ( +
+ +

PPT 文件暂不支持在线预览,请下载后查看

+ +
+ } + type="info" + showIcon + /> +
+ ); + } + + if (!isPreviewable()) { return ( @@ -230,23 +273,26 @@ const DocumentPreview: FC = ({ message="预览失败" description={
-

无法加载文档预览,可能的原因:

-
    -
  • 文件需要认证访问,Office 预览服务无法访问
  • -
  • 文件 URL 无法公开访问(需要配置公开访问或临时签名 URL)
  • -
  • 文件大小超过限制(Office 预览通常限制 10MB)
  • -
  • 预览服务暂时不可用
  • +

    无法加载文档预览

    + {errorMessage && ( +

    + 错误详情:{errorMessage} +

    + )} +

    可能的原因:

    +
      +
    • 文件 URL 无法访问(401/403/404)
    • +
    • 认证 token 已过期
    • +
    • 文件格式损坏或不匹配
    • +
    • 网络连接问题
    -

    建议:请下载文件到本地查看

    - {showModeSwitch && !isPdfFile() && ( - - )} +
} @@ -256,26 +302,23 @@ const DocumentPreview: FC = ({
)} - {/* 图片文件预览 */} {isImageFile() && !error && !loading && (
{fileName setError(true)} + onError={() => handleError('图片加载失败')} />
)} - {/* Markdown 文件预览 */} {isMarkdownFile() && !error && !loading && (
)} - {/* 文本文件预览 */} {isTextFile() && !error && !loading && (
@@ -284,44 +327,52 @@ const DocumentPreview: FC = ({
         
)} - {/* PDF 文件预览(使用浏览器内置预览) */} + {isWordFile() && !error && !loading && ( +
+
+
+ )} + + {isExcelFile() && !error && !loading && ( +
+ {excelData.map((sheet, index) => ( +
+

{sheet.sheetName}

+ {sheet.data.length > 0 && ( + ({ key: idx, ...row }))} + columns={sheet.data[0]?.map((header: any, colIdx: number) => ({ + title: header || `列 ${colIdx + 1}`, + dataIndex: colIdx, + key: colIdx, + width: 150, + })) || []} + pagination={false} + scroll={{ x: 'max-content' }} + size="small" + bordered + /> + )} + + ))} + + )} + {isPdfFile() && !error && !loading && (