diff --git a/.gitignore b/.gitignore index 66d1beb2..ae3261f0 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,8 @@ examples/ time.log celerybeat-schedule.db search_results.json +redbear-mem-metrics/ +pitch-deck/ api/migrations/versions tmp diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 585de2ed..451dcdf7 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -13,6 +13,7 @@ from . import ( document_controller, emotion_config_controller, emotion_controller, + end_user_controller, file_controller, file_storage_controller, home_page_controller, @@ -96,5 +97,6 @@ manager_router.include_router(file_storage_controller.router) manager_router.include_router(ontology_controller.router) manager_router.include_router(skill_controller.router) manager_router.include_router(i18n_controller.router) +manager_router.include_router(end_user_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/end_user_controller.py b/api/app/controllers/end_user_controller.py new file mode 100644 index 00000000..b9d54fea --- /dev/null +++ b/api/app/controllers/end_user_controller.py @@ -0,0 +1,48 @@ +"""End User 管理接口 - 无需认证""" + +from app.core.logging_config import get_business_logger +from app.core.response_utils import success +from app.db import get_db +from app.repositories.end_user_repository import EndUserRepository +from app.schemas.memory_api_schema import ( + CreateEndUserRequest, + CreateEndUserResponse, +) +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +router = APIRouter(prefix="/end_users", tags=["End Users"]) +logger = get_business_logger() + + +@router.post("") +async def create_end_user( + data: CreateEndUserRequest, + db: Session = Depends(get_db), +): + """ + Create an end user. + + Creates a new end user for the given workspace. + If an end user with the same other_id already exists in the workspace, + returns the existing one. + """ + logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}") + + end_user_repo = EndUserRepository(db) + end_user = end_user_repo.get_or_create_end_user( + app_id=None, + workspace_id=data.workspace_id, + other_id=data.other_id, + ) + + logger.info(f"End user ready: {end_user.id}") + + result = { + "id": str(end_user.id), + "other_id": end_user.other_id or "", + "other_name": end_user.other_name or "", + "workspace_id": str(end_user.workspace_id), + } + + return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") diff --git a/api/app/controllers/file_storage_controller.py b/api/app/controllers/file_storage_controller.py index ff284f39..55149cce 100644 --- a/api/app/controllers/file_storage_controller.py +++ b/api/app/controllers/file_storage_controller.py @@ -91,7 +91,7 @@ async def upload_file( if file_size > settings.MAX_FILE_SIZE: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_413_CONTENT_TOO_LARGE, detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit" ) @@ -172,7 +172,6 @@ async def upload_file_with_share_token( # Get share and release info from share_token service = ReleaseShareService(db) - share_info = service.get_shared_release_info(share_token=share_data.share_token) # Get share object to access app_id share = service.repo.get_by_share_token(share_data.share_token) @@ -499,6 +498,51 @@ async def get_file_url( ) +@router.get("/files/{file_id}/public-url", response_model=ApiResponse) +async def get_permanent_file_url( + file_id: uuid.UUID, + db: Session = Depends(get_db), + storage_service: FileStorageService = Depends(get_file_storage_service), +): + """ + 获取文件的永久公开 URL(无过期时间)。 + + - 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置) + - 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限) + """ + file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first() + if not file_metadata: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist") + + if file_metadata.status != "completed": + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"File upload not completed, status: {file_metadata.status}") + + file_key = file_metadata.file_key + storage = storage_service.storage + + try: + if isinstance(storage, LocalStorage): + url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}" + else: + url = await storage.get_permanent_url(file_key) + if not url: + raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Permanent URL not supported for current storage backend") + + api_logger.info(f"Generated permanent URL: file_id={file_id}") + return success( + data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name}, + msg="Permanent file URL generated successfully" + ) + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Failed to generate permanent URL: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to generate permanent URL: {str(e)}") + + @router.get("/public/{file_id}", response_model=Any) async def public_download_file( request: Request, diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 2c979435..cc0efab3 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -195,10 +195,9 @@ async def get_workspace_end_users( api_logger.warning(f"Redis 缓存写入失败: {str(e)}") # 触发社区聚类补全任务(异步,不阻塞接口响应) - # 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类 try: from app.tasks import init_community_clustering_for_users - init_community_clustering_for_users.delay(end_user_ids=end_user_ids) + init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id)) api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}") except Exception as e: api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}") diff --git a/api/app/controllers/memory_working_controller.py b/api/app/controllers/memory_working_controller.py index 8aab039a..c06fd432 100644 --- a/api/app/controllers/memory_working_controller.py +++ b/api/app/controllers/memory_working_controller.py @@ -33,35 +33,47 @@ def get_memory_count( @router.get("/{end_user_id}/conversations", response_model=ApiResponse) def get_conversations( end_user_id: uuid.UUID, + page: int = 1, + pagesize: int = 20, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ - Retrieve all conversations for the current user in a specific group. + Retrieve conversations for the current user in a specific group with pagination. Args: end_user_id (UUID): The group identifier. + page (int): Page number (1-based). Defaults to 1. + pagesize (int): Number of items per page. Defaults to 20. current_user (User, optional): The authenticated user. db (Session, optional): SQLAlchemy session. Returns: - ApiResponse: Contains a list of conversation IDs. - - Notes: - - Initializes the ConversationService with the current DB session. - - Returns only conversation IDs for lightweight response. - - Logs can be added to trace requests in production. + ApiResponse: Contains a paginated list of conversations. """ + page = max(1, page) + page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100 conversation_service = ConversationService(db) - conversations = conversation_service.get_user_conversations( - end_user_id + conversations, total = conversation_service.get_user_conversations( + end_user_id, + page=page, + page_size=page_size ) - return success(data=[ - { - "id": conversation.id, - "title": conversation.title - } for conversation in conversations - ], msg="get conversations success") + return success(data={ + "items": [ + { + "id": conversation.id, + "title": conversation.title + } for conversation in conversations + ], + "total": total, + "page": { + "page": page, + "pagesize": page_size, + "total": total, + "hasnext": (page * page_size) < total + }, + }, msg="get conversations success") @router.get("/{end_user_id}/messages", response_model=ApiResponse) diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index 34489e8a..08a94a89 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -6,6 +6,7 @@ from app.core.response_utils import success from app.db import get_db from app.schemas.api_key_schema import ApiKeyAuth from app.schemas.memory_api_schema import ( + ListConfigsResponse, MemoryReadRequest, MemoryReadResponse, MemoryWriteRequest, @@ -31,14 +32,15 @@ async def write_memory_api_service( request: Request, api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), - payload: MemoryWriteRequest = Body(..., embed=False), - + message: str = Body(..., description="Message content"), ): """ Write memory to storage. Stores memory content for the specified end user using the Memory API Service. """ + body = await request.json() + payload = MemoryWriteRequest(**body) logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}") memory_api_service = MemoryAPIService(db) @@ -62,13 +64,15 @@ async def read_memory_api_service( request: Request, api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), - payload: MemoryReadRequest = Body(..., embed=False), + message: str = Body(..., description="Query message"), ): """ Read memory from storage. Queries and retrieves memories for the specified end user with context-aware responses. """ + body = await request.json() + payload = MemoryReadRequest(**body) logger.info(f"Memory read request - end_user_id: {payload.end_user_id}") memory_api_service = MemoryAPIService(db) @@ -85,3 +89,27 @@ async def read_memory_api_service( logger.info(f"Memory read successful for end_user: {payload.end_user_id}") return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully") + + +@router.get("/configs") +@require_api_key(scopes=["memory"]) +async def list_memory_configs( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + List all memory configs for the workspace. + + Returns all available memory configurations associated with the authorized workspace. + """ + logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}") + + memory_api_service = MemoryAPIService(db) + + result = memory_api_service.list_memory_configs( + workspace_id=api_key_auth.workspace_id, + ) + + logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}") + return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully") diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index 5563b9d7..74b8d88e 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -76,6 +76,8 @@ async def get_tool_methods( if methods is None: raise HTTPException(status_code=404, detail="工具不存在") return success(data=methods, msg="获取工具方法成功") + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -121,6 +123,8 @@ async def create_tool( raise HTTPException(status_code=400, detail=e.message) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -149,6 +153,8 @@ async def update_tool( return success(msg="工具更新成功") except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -191,6 +197,8 @@ async def set_tool_active( return success(msg=f"工具已{action}") except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -223,6 +231,8 @@ async def execute_tool( }, msg="工具执行完成" ) + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/api/app/core/config.py b/api/app/core/config.py index cdaa13cc..4a944557 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -97,6 +97,7 @@ class Settings: # File Upload MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800")) + MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20")) FILE_PATH: str = os.getenv("FILE_PATH", "/files") FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600")) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 02aa1b44..b62eb50a 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -166,15 +166,12 @@ async def write( statement_entity_edges=all_statement_entity_edges, entity_edges=all_entity_entity_edges, connector=neo4j_connector, - config_id=config_id, - llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, ) if success: logger.info("Successfully saved all data to Neo4j") # 写入成功后,异步触发聚类(不阻塞写入响应) schedule_clustering_after_write( all_entity_nodes, - config_id=config_id, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, ) diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 21257f2e..d9c04f8b 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -69,15 +69,15 @@ class LabelPropagationEngine: def __init__( self, connector: Neo4jConnector, - config_id: Optional[str] = None, llm_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ): self.connector = connector self.repo = CommunityRepository(connector) - self.config_id = config_id self.llm_model_id = llm_model_id self.embedding_model_id = embedding_model_id + self.embedding_model_id = embedding_model_id # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -439,15 +439,17 @@ class LabelPropagationEngine: @staticmethod def _build_entity_lines(members: List[Dict]) -> List[str]: - """将实体列表格式化为 prompt 行,包含 name、aliases、description。""" + """将实体列表格式化为 prompt 行,包含 name、aliases、description、example。""" lines = [] for m in members: m_name = m.get("name", "") aliases = m.get("aliases") or [] description = m.get("description") or "" + example = m.get("example") or "" aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else "" desc_str = f":{description}" if description else "" - lines.append(f"- {m_name}{aliases_str}{desc_str}") + example_str = f"(示例:{example})" if example else "" + lines.append(f"- {m_name}{aliases_str}{desc_str}{example_str}") return lines async def _generate_community_metadata( @@ -481,11 +483,24 @@ class LabelPropagationEngine: core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] entity_list_str = "\n".join(self._build_entity_lines(members)) + + # 方案四:注入社区内实体间关系三元组 + relationships = await self.repo.get_community_relationships(cid, end_user_id) + rel_lines = [ + f"- {r['subject']} → {r['predicate']} → {r['object']}" + for r in relationships + if r.get("subject") and r.get("predicate") and r.get("object") + ] + rel_section = ( + f"\n实体间关系:\n" + "\n".join(rel_lines) + if rel_lines else "" + ) + prompt = ( - f"以下是一组语义相关的实体:\n{entity_list_str}\n\n" + f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" f"请为这组实体所代表的主题:\n" f"1. 起一个简洁的中文名称(不超过10个字)\n" - f"2. 写一句话摘要(不超过50个字)\n\n" + f"2. 写一句话摘要(不超过80个字)\n\n" f"严格按以下格式输出,不要有其他内容:\n" f"名称:<名称>\n摘要:<摘要>" ) diff --git a/api/app/core/storage/base.py b/api/app/core/storage/base.py index 6653d04a..8ab0fcde 100644 --- a/api/app/core/storage/base.py +++ b/api/app/core/storage/base.py @@ -121,3 +121,18 @@ class StorageBackend(ABC): URL for accessing the file. """ pass + + async def get_permanent_url(self, file_key: str) -> Optional[str]: + """ + Get a permanent public URL for the file (no expiration). + + Returns None by default; remote storage backends should override this + if the bucket is configured for public read access. + + Args: + file_key: Unique identifier for the file in the storage system. + + Returns: + A permanent public URL, or None if not supported. + """ + return None diff --git a/api/app/core/storage/oss.py b/api/app/core/storage/oss.py index 81bedce1..27669ffa 100644 --- a/api/app/core/storage/oss.py +++ b/api/app/core/storage/oss.py @@ -261,3 +261,13 @@ class OSSStorage(StorageBackend): logger.error(f"Failed to generate presigned URL for {file_key}: {e}") # Return a basic URL format as fallback return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}" + + async def get_permanent_url(self, file_key: str) -> str: + """ + Get a permanent public URL for the file (requires bucket public read). + + Returns: + A permanent URL in the format: https://{bucket}.{endpoint}/{file_key} + """ + host = self.endpoint.replace("https://", "").replace("http://", "") + return f"https://{self.bucket_name}.{host}/{file_key}" diff --git a/api/app/core/storage/s3.py b/api/app/core/storage/s3.py index 37ad4184..c7b33ffe 100644 --- a/api/app/core/storage/s3.py +++ b/api/app/core/storage/s3.py @@ -378,3 +378,12 @@ class S3Storage(StorageBackend): logger.error(f"Failed to generate presigned URL for {file_key}: {e}") # Return a basic URL format as fallback return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}" + + async def get_permanent_url(self, file_key: str) -> str: + """ + Get a permanent public URL for the file (requires bucket public read). + + Returns: + A permanent URL in the format: https://{bucket}.s3.{region}.amazonaws.com/{file_key} + """ + return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}" diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 90668ad9..674c45d0 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -20,9 +20,21 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes import NodeFactory from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES from app.core.workflow.utils.expression_evaluator import evaluate_condition +from app.core.workflow.validator import WorkflowValidator logger = logging.getLogger(__name__) +# Regex to split output into: +# - variable placeholders: {{ ... }} +# - normal literal text +# +# Example: +# "Hello {{user.name}}!" -> +# ["Hello ", "{{user.name}}", "!"] +_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+') +# Strict variable format: {{ node_id.field_name }} +_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}') + class GraphBuilder: def __init__( @@ -37,13 +49,13 @@ class GraphBuilder: self.stream = stream self.subgraph = subgraph - self.start_node_id = None - self.end_node_ids = [] + self.start_node_id: str | None = None + self.node_map = {node["id"]: node for node in self.nodes} self.end_node_map: dict[str, StreamOutputConfig] = {} - self._find_upstream_branch_node = lru_cache( + self._find_upstream_activation_dep = lru_cache( maxsize=len(self.nodes) * 2 - )(self._find_upstream_branch_node) + )(self._find_upstream_activation_dep) if variable_pool: self.variable_pool = variable_pool else: @@ -51,10 +63,19 @@ class GraphBuilder: self.graph = StateGraph(WorkflowState) self.add_nodes() + self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) + self.end_nodes = [ + node + for node in self.nodes + if node.get("type") == "end" and node.get("id") in self.reachable_nodes + ] self.add_edges() - self._analyze_end_node_output() # EDGES MUST BE ADDED AFTER NODES ARE ADDED. + self._reverse_adj: dict[str, list[dict]] = defaultdict(list) + self._build_reverse_adj() + self._analyze_end_node_output() + @property def nodes(self) -> list[dict[str, Any]]: return self.workflow_config.get("nodes", []) @@ -87,60 +108,50 @@ class GraphBuilder: result[node[0]].append(node[1]) return result - def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]: - """ - Recursively find all upstream branch (control) nodes that influence the execution - of the given target node. + def _build_reverse_adj(self): + for edge in self.edges: + if edge["source"] not in self.reachable_nodes: + continue + self._reverse_adj[edge.get("target")].append({ + "id": edge["source"], "branch": edge.get("label") + }) - This method walks upstream along the workflow graph starting from `target_node`. - It distinguishes between: - - branch nodes (node types listed in `BRANCH_NODES`) - - non-branch nodes (ordinary processing nodes) + def _find_upstream_activation_dep( + self, + target_node: str + ) -> tuple[tuple[tuple[str, str]], tuple[str]]: + """Find upstream dependencies that affect the activation of a target node. - Traversal rules: - 1. For each immediate upstream node: - - If it is a branch node, it is recorded as an affecting control node. - - If it is a non-branch node, the traversal continues recursively upstream. - 2. If ANY upstream path reaches a START / CYCLE_START node without encountering - a branch node, the traversal is considered invalid: - - `has_branch` will be False - - no branch nodes are returned. - 3. Only when ALL upstream non-branch paths eventually lead to at least one - branch node will `has_branch` be True. + Walks upstream along the workflow graph from the target node, collecting + two types of dependencies: + - Branch control nodes: upstream branch nodes (e.g. if-else) whose + routing outcome determines whether the target node executes. + - Output nodes: upstream END nodes that must complete their output + before the target node can activate. - Special case: - - If `target_node` has no upstream nodes AND its type is START or CYCLE_START, - it is considered directly reachable from the workflow entry, and therefore - has no controlling branch nodes. + The traversal terminates early and returns empty tuples if any upstream + path reaches START/CYCLE_START without encountering a branch or output + node, indicating the target node is directly reachable and should be + activated immediately. Args: - target_node (str): - The identifier of the node whose upstream control branches - are to be resolved. + target_node: The ID of the node whose upstream activation + dependencies are to be resolved. Returns: - tuple[bool, tuple[tuple[str, str]]]: - - has_branch (bool): - True if every upstream path from `target_node` encounters - at least one branch node. - False if any path reaches a start node without a branch. - - branch_nodes (tuple[tuple[str, str]]): - A deduplicated tuple of `(branch_node_id, branch_label)` pairs - representing all branch nodes that can influence `target_node`. - Returns an empty tuple if `has_branch` is False. + A tuple of two elements: + - A deduplicated tuple of (branch_node_id, branch_label) pairs + representing upstream branch control dependencies. Empty if + any clean path to START exists. + - A deduplicated tuple of upstream output node IDs that must + complete before this node activates. """ - source_nodes = [ - { - "id": edge.get("source"), - "branch": edge.get("label") - } - for edge in self.edges - if edge.get("target") == target_node - ] + source_nodes = self._reverse_adj[target_node] if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]: - return False, tuple() + return tuple(), tuple() branch_nodes = [] + output_nodes = [] non_branch_nodes = [] for node_info in source_nodes: @@ -149,19 +160,23 @@ class GraphBuilder: (node_info["id"], node_info["branch"]) ) else: + if self.get_node_type(node_info["id"]) == NodeType.END: + output_nodes.append(node_info["id"]) non_branch_nodes.append(node_info["id"]) has_branch = True for node_id in non_branch_nodes: - node_has_branch, nodes = self._find_upstream_branch_node(node_id) - has_branch = has_branch and node_has_branch - if not has_branch: - break - branch_nodes.extend(nodes) - if not has_branch: - branch_nodes = [] + upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id) + if not upstream_control_nodes: + if not upstream_output_nodes and node_id not in output_nodes: + return tuple(), tuple() + branch_nodes = [] + has_branch = False + if has_branch: + branch_nodes.extend(upstream_control_nodes) + output_nodes.extend(upstream_output_nodes) - return has_branch, tuple(set(branch_nodes)) + return tuple(set(branch_nodes)), tuple(set(output_nodes)) def _analyze_end_node_output(self): """ @@ -182,11 +197,10 @@ class GraphBuilder: """ # Collect all End nodes in the workflow - end_nodes = [node for node in self.nodes if node.get("type") == "end"] - logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes") + logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes") # Iterate through each End node to analyze its output - for end_node in end_nodes: + for end_node in self.end_nodes: end_node_id = end_node.get("id") config = end_node.get("config", {}) output = config.get("output") @@ -195,42 +209,33 @@ class GraphBuilder: if not output: continue - # Regex to split output into: - # - variable placeholders: {{ ... }} - # - normal literal text - # - # Example: - # "Hello {{user.name}}!" -> - # ["Hello ", "{{user.name}}", "!"] - pattern = r'\{\{.*?\}\}|[^{}]+' - - # Strict variable format: {{ node_id.field_name }} - variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}' - variable_pattern = re.compile(variable_pattern_string) - # Split output into ordered segments - output_template = list(re.findall(pattern, output)) + output_template = list(_OUTPUT_PATTERN.findall(output)) # Determine whether each segment is literal text # True -> literal (can be directly output) # False -> variable placeholder (needs runtime value) output_flag = [ - not bool(variable_pattern.match(item)) + not bool(_VARIABLE_PATTERN.match(item)) for item in output_template ] # Stream mode: output activation depends on upstream branch nodes if self.stream: # Find upstream branch nodes that can control this End node - has_branch, control_nodes = self._find_upstream_branch_node(end_node_id) - + upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id) + activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes) # Build StreamOutputConfig for this End node self.end_node_map[end_node_id] = StreamOutputConfig( + id=end_node_id, # If there is no upstream branch, output is active immediately - activate=not has_branch, + activate=activate, # Branch nodes that control activation of this End node - control_nodes=self._merge_control_nodes(control_nodes), + control_nodes=self._merge_control_nodes(upstream_control_nodes), + upstream_output_nodes=list(upstream_output_nodes), + control_resolved=not bool(upstream_control_nodes), + output_resolved=not bool(upstream_output_nodes), # Convert output segments into OutputContent objects outputs=list( @@ -249,14 +254,16 @@ class GraphBuilder: cursor=0 ) logger.info(f"[Stream Analysis] end_id: {end_node_id}, " - f"activate: {not has_branch}, " - f"control_nodes: {control_nodes}," + f"activate: {activate}, " + f"control_nodes: {upstream_control_nodes}," + f"ref_outputs: {upstream_output_nodes}," f"output: {output_template}," f"output_activate: {output_flag}") # Non-stream mode: all outputs are activated by default else: self.end_node_map[end_node_id] = StreamOutputConfig( + id=end_node_id, activate=True, control_nodes={}, outputs=list( @@ -269,7 +276,10 @@ class GraphBuilder: for output_string, activate in zip(output_template, output_flag) ] ), - cursor=0 + cursor=0, + upstream_output_nodes=[], + control_resolved=True, + output_resolved=True, ) def add_nodes(self): @@ -304,8 +314,6 @@ class GraphBuilder: # Record start and end node IDs if node_type in [NodeType.START, NodeType.CYCLE_START]: self.start_node_id = node_id - elif node_type == NodeType.END: - self.end_node_ids.append(node_id) # Create node instance (start and end nodes are also created) # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph @@ -448,7 +456,7 @@ class GraphBuilder: branch_activate = [] new_state = state.copy() new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate - node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False) + node_output = variable_pool.get_node_output(src, default=dict(), strict=False) for label, branch in unique_branch.items(): if node_output and evaluate_condition( branch["condition"], @@ -494,9 +502,11 @@ class GraphBuilder: logger.debug(f"Added waiting edge: {sources} -> {target}") # Connect End nodes to the global END node - for end_node_id in self.end_node_ids: - self.graph.add_edge(end_node_id, END) - logger.debug(f"Added edge: {end_node_id} -> END") + for end_node in self.end_nodes: + end_node_id = end_node.get("id") + if end_node_id: + self.graph.add_edge(end_node_id, END) + logger.debug(f"Added edge: {end_node_id} -> END") return def build(self) -> CompiledStateGraph: diff --git a/api/app/core/workflow/engine/result_builder.py b/api/app/core/workflow/engine/result_builder.py index 31bccf57..e5a03c1c 100644 --- a/api/app/core/workflow/engine/result_builder.py +++ b/api/app/core/workflow/engine/result_builder.py @@ -12,6 +12,7 @@ class WorkflowResultBuilder: variable_pool: VariablePool, elapsed_time: float, final_output: str, + success: bool ): """Construct the final standardized output of the workflow execution. @@ -29,6 +30,7 @@ class WorkflowResultBuilder: elapsed_time (float): Total execution time in seconds. final_output (Any): The aggregated or final output content of the workflow (e.g., combined messages from all End nodes). + success (bool): Whether the execution was successful. Returns: dict: A dictionary containing the final workflow execution result with keys: @@ -49,7 +51,7 @@ class WorkflowResultBuilder: conversation_id = variable_pool.get_value("sys.conversation_id") return { - "status": "completed", + "status": "completed" if success else "failed", "output": final_output, "variables": { "conv": variable_pool.get_all_conversation_vars(), diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index ddee9adc..6685a49e 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -3,6 +3,7 @@ # @Email: 1533512157@qq.com # @Time : 2026/2/9 15:11 import re +from queue import Queue from typing import AsyncGenerator from pydantic import BaseModel, Field, PrivateAttr @@ -37,8 +38,8 @@ class OutputContent(BaseModel): activate: bool = Field( ..., description=( - "Whether this output segment is currently active.\n" - "- True: allowed to be emitted/output\n" + "Whether this output segment is currently active." + "- True: allowed to be emitted/output" "- False: blocked until activated by branch control" ) ) @@ -46,8 +47,8 @@ class OutputContent(BaseModel): is_variable: bool = Field( ..., description=( - "Whether this segment represents a variable placeholder.\n" - "True -> variable (e.g. {{ node.field }})\n" + "Whether this segment represents a variable placeholder." + "True -> variable (e.g. {{ node.field }})" "False -> literal text" ) ) @@ -86,12 +87,16 @@ class StreamOutputConfig(BaseModel): - which upstream branch/control nodes gate the activation - how each parsed output segment is streamed and activated """ + id: str = Field( + ..., + description="ID of the End node this configuration belongs to." + ) activate: bool = Field( ..., description=( - "Global activation flag for the End node output.\n" - "When False, output segments should not be emitted even if available.\n" + "Global activation flag for the End node output." + "When False, output segments should not be emitted even if available." "This flag typically becomes True once required control branch conditions " "are satisfied." ) @@ -100,17 +105,46 @@ class StreamOutputConfig(BaseModel): control_nodes: dict[str, list[str]] = Field( ..., description=( - "Control branch conditions for this End node output.\n" - "Mapping of `branch_node_id -> expected_branch_label`.\n" + "Control branch conditions for this End node output." + "Mapping of `branch_node_id -> expected_branch_label`." "The End node output becomes globally active when a controlling branch node " "reports a matching completion status." ) ) + upstream_output_nodes: list[str] = Field( + ..., + description=( + "Upstream output node dependencies (data flow)." + "Represents END/output nodes that this output depends on." + "These nodes provide data sources required before this output can be activated " + "or streamed." + "Used to ensure correct ordering and dependency resolution in streaming mode." + ) + ) + + control_resolved: bool = Field( + ..., + description=( + "Whether all upstream branch control dependencies have been satisfied." + "True if no upstream branch nodes exist or the required branch " + "conditions have been met." + ) + ) + + output_resolved: bool = Field( + ..., + description=( + "Whether all upstream output node dependencies have been completed." + "True if no upstream output nodes exist or all upstream output " + "nodes have finished their output." + ) + ) + outputs: list[OutputContent] = Field( ..., description=( - "Ordered list of output segments parsed from the output template.\n" + "Ordered list of output segments parsed from the output template." "Each segment represents either a literal text block or a variable placeholder " "that may be activated independently." ) @@ -119,49 +153,97 @@ class StreamOutputConfig(BaseModel): cursor: int = Field( ..., description=( - "Streaming cursor index.\n" - "Indicates the next output segment index to be emitted.\n" + "Streaming cursor index." + "Indicates the next output segment index to be emitted." "Segments with index < cursor are considered already streamed." ) ) + force: bool = Field( + default=False, + description=( + "Force flag for output emission." + "When True, all output segments are emitted regardless of activation state." + "Triggered when this output node has finished execution." + ) + ) + def update_activate(self, scope: str, status=None): """ - Update streaming activation state based on an upstream node or special variable. + Update streaming activation state based on upstream events. Args: scope (str): Identifier of the completed upstream entity. - If a control branch node, it should match a key in `control_nodes`. - - If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments. + - If an upstream output node, it should match an entry in `upstream_output_nodes`. + - If a variable placeholder (e.g., "sys.xxx" or "node_id.field"), + it may appear in output segments. + status (optional): Completion status of the control branch node. Required when `scope` refers to a control node. Behavior: - 1. Control branch nodes: - - If `scope` matches a key in `control_nodes` and `status` matches the expected - branch label, the End node output becomes globally active (`activate = True`). + 1. Force activation: + - If `self.force` is True, the method returns immediately. + - If `scope == self.id`, the node marks itself as completed: + - `activate = True` + - `force = True` + This is typically used for final flushing when the node finishes execution. - 2. Variable output segments: - - For each segment that is a variable (`is_variable=True`): - - If the segment literal references `scope`, mark the segment as active. - - This applies both to regular node variables (e.g., "node_id.field") - and special system variables (e.g., "sys.xxx"). + 2. Control dependency resolution: + - If `scope` matches a key in `control_nodes`: + - `status` must be provided. + - If `status` matches expected branch labels, mark control as resolved + (`control_resolved = True`). + + 3. Upstream output dependency resolution: + - If `scope` is in `upstream_output_nodes`, + mark data dependency as resolved (`output_resolved = True`). + + 4. Global activation condition: + - The node becomes active when BOTH conditions are satisfied: + - control_resolved == True + - output_resolved == True + - Once activated, `activate` remains True. + + 5. Variable segment activation: + - For each output segment that is a variable (`is_variable=True`): + - If the segment depends on the given `scope`, + mark the segment as active. + - This applies to both node variables (e.g., "node_id.field") + and system variables (e.g., "sys.xxx"). Notes: - - This method does not emit output or advance the streaming cursor. - - It only updates activation flags based on upstream events or special variables. + - This method does NOT emit output or advance the streaming cursor. + - It only updates activation and dependency resolution states. + - Activation is driven by both control flow (branch nodes) and + data flow (upstream output nodes). """ + if self.force: + return - # Case 1: resolve control branch dependency + if scope == self.id: + self.activate = True + self.force = True + return + + # resolve control branch dependency if scope in self.control_nodes: if status is None: raise RuntimeError("[Stream Output] Control node activation status not provided") if status in self.control_nodes[scope]: - self.activate = True + self.control_resolved = True - # Case 2: activate variable segments related to this node + if scope in self.upstream_output_nodes: + self.upstream_output_nodes.remove(scope) + if not self.upstream_output_nodes: + self.output_resolved = True + + self.activate = self.activate or (self.control_resolved and self.output_resolved) + + # activate variable segments related to this node for i in range(len(self.outputs)): if ( self.outputs[i].is_variable @@ -174,12 +256,17 @@ class StreamOutputCoordinator: def __init__(self): self.end_outputs: dict[str, StreamOutputConfig] = {} self.activate_end: str | None = None + self.output_queue: Queue = Queue() + self.processed_outputs = [] def initialize_end_outputs( self, end_node_map: dict[str, StreamOutputConfig] ): self.end_outputs = end_node_map + self.processed_outputs = [] + self.activate_end = None + self.output_queue = Queue() @property def current_activate_end_info(self): @@ -211,8 +298,11 @@ class StreamOutputCoordinator: """ for node in self.end_outputs.keys(): self.end_outputs[node].update_activate(scope, status) - if self.end_outputs[node].activate and self.activate_end is None: - self.activate_end = node + if self.end_outputs[node].activate and node not in self.processed_outputs: + self.output_queue.put(node) + self.processed_outputs.append(node) + if self.activate_end is None and not self.output_queue.empty(): + self.activate_end = self.output_queue.get_nowait() async def emit_activate_chunk( self, @@ -256,7 +346,7 @@ class StreamOutputCoordinator: final_chunk = '' current_segment = end_info.outputs[end_info.cursor] - if not current_segment.activate and not force: + if not current_segment.activate and not force and not end_info.force: # Stop processing until this segment becomes active break @@ -273,7 +363,7 @@ class StreamOutputCoordinator: logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}") if final_chunk: - logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}") + logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}") yield { "event": "message", "data": { @@ -285,8 +375,7 @@ class StreamOutputCoordinator: end_info.cursor += 1 if end_info.cursor >= len(end_info.outputs): - self.end_outputs.pop(self.activate_end) - self.activate_end = None + self.pop_current_activate_end() async def flush_remaining_chunk( self, @@ -325,6 +414,8 @@ class StreamOutputCoordinator: async for msg_event in self.emit_activate_chunk(variable_pool, force=True): yield msg_event + if not self.output_queue.empty(): + self.activate_end = self.output_queue.get_nowait() # Move to next active End node if current one is done if not self.activate_end and self.end_outputs: self.activate_end = list(self.end_outputs.keys())[0] diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index bc88df19..cf6f4a7b 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -351,12 +351,12 @@ class VariablePool: } return runtime_vars - def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None: + def get_node_output(self, node_id: str, default: Any = None, strict: bool = True) -> dict[str, Any] | None: """获取指定节点的输出(运行时变量) Args: node_id: 节点 ID - defalut: 默认值 + default: 默认值 strict: 是否严格模式 Returns: @@ -368,7 +368,7 @@ class VariablePool: if strict: raise KeyError(f"node {node_id} output not exist") else: - return defalut + return default def copy(self, pool: 'VariablePool'): self.variables = deepcopy(pool.variables) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index ff979f2b..c9ed6e65 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -128,89 +128,100 @@ class WorkflowExecutor: - token_usage: aggregated token usage if available - error: error message if any """ - logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}") - - start_time = datetime.datetime.now() - - # Execute the workflow - try: - # Build the workflow graph - graph = self.build_graph() - - # Initialize the variable pool with input data - await self.variable_initializer.initialize( - variable_pool=self.variable_pool, - input_data=input_data, - execution_context=self.execution_context - ) - initial_state = self.state_manager.create_initial_state( - workflow_config=self.workflow_config, - input_data=input_data, - execution_context=self.execution_context, - start_node_id=self.start_node_id - ) - - result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config) - - # Aggregate output from all End nodes - full_content = '' - for end_id in self.stream_coordinator.end_outputs.keys(): - full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) - - # Append messages for user and assistant - if input_data.get("files"): - result["messages"].extend( - [ - { - "role": "user", - "content": input_data.get("message", '') - }, - { - "role": "user", - "content": input_data.get("files") - }, - { - "role": "assistant", - "content": full_content - } - ] - ) - else: - result["messages"].extend( - [ - { - "role": "user", - "content": input_data.get("message", '') - }, - { - "role": "assistant", - "content": full_content - } - ] - ) - # Calculate elapsed time - end_time = datetime.datetime.now() - elapsed_time = (end_time - start_time).total_seconds() - - logger.info( - f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms") - - return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) - - except Exception as e: - end_time = datetime.datetime.now() - elapsed_time = (end_time - start_time).total_seconds() - - logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}", - exc_info=True) - return { - "status": "failed", - "error": str(e), - "output": None, - "node_outputs": {}, - "elapsed_time": elapsed_time, - "token_usage": None - } + start = datetime.datetime.now() + async for event in self.execute_stream(input_data): + if event.get("event") == "workflow_end": + return event.get("data") + return self.result_builder.build_final_output( + {"error": "Workflow execution did not end as expected"}, + self.variable_pool, + (datetime.datetime.now() - start).total_seconds(), + "", + success=False + ) + # logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}") + # + # start_time = datetime.datetime.now() + # + # # Execute the workflow + # try: + # # Build the workflow graph + # graph = self.build_graph() + # + # # Initialize the variable pool with input data + # await self.variable_initializer.initialize( + # variable_pool=self.variable_pool, + # input_data=input_data, + # execution_context=self.execution_context + # ) + # initial_state = self.state_manager.create_initial_state( + # workflow_config=self.workflow_config, + # input_data=input_data, + # execution_context=self.execution_context, + # start_node_id=self.start_node_id + # ) + # + # result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config) + # + # # Aggregate output from all End nodes + # full_content = '' + # for end_id in self.stream_coordinator.end_outputs.keys(): + # full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) + # + # # Append messages for user and assistant + # if input_data.get("files"): + # result["messages"].extend( + # [ + # { + # "role": "user", + # "content": input_data.get("message", '') + # }, + # { + # "role": "user", + # "content": input_data.get("files") + # }, + # { + # "role": "assistant", + # "content": full_content + # } + # ] + # ) + # else: + # result["messages"].extend( + # [ + # { + # "role": "user", + # "content": input_data.get("message", '') + # }, + # { + # "role": "assistant", + # "content": full_content + # } + # ] + # ) + # # Calculate elapsed time + # end_time = datetime.datetime.now() + # elapsed_time = (end_time - start_time).total_seconds() + # + # logger.info( + # f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms") + # + # return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) + # + # except Exception as e: + # end_time = datetime.datetime.now() + # elapsed_time = (end_time - start_time).total_seconds() + # + # logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}", + # exc_info=True) + # return { + # "status": "failed", + # "error": str(e), + # "output": None, + # "node_outputs": {}, + # "elapsed_time": elapsed_time, + # "token_usage": None + # } async def execute_stream( self, @@ -248,7 +259,8 @@ class WorkflowExecutor: "timestamp": int(start_time.timestamp() * 1000) } } - + result = None + full_content = '' try: # Build the workflow graph in streaming mode graph = self.build_graph(stream=True) @@ -266,7 +278,6 @@ class WorkflowExecutor: start_node_id=self.start_node_id ) - full_content = '' self.stream_coordinator.update_scope_activation("sys") # Execute the workflow with streaming @@ -363,7 +374,12 @@ class WorkflowExecutor: yield { "event": "workflow_end", - "data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) + "data": self.result_builder.build_final_output( + result, + self.variable_pool, + elapsed_time, + full_content, + success=True) } except Exception as e: @@ -372,16 +388,19 @@ class WorkflowExecutor: logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}", exc_info=True) - + if result is None: + result = {"error": str(e)} + else: + result["error"] = str(e) yield { "event": "workflow_end", - "data": { - "execution_id": self.execution_context.execution_id, - "status": "failed", - "error": str(e), - "elapsed_time": elapsed_time, - "timestamp": end_time.isoformat() - } + "data": self.result_builder.build_final_output( + result, + self.variable_pool, + elapsed_time, + full_content, + success=False + ) } diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index 9303302d..1e055002 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -128,7 +128,7 @@ class CodeNode(BaseNode): else: raise ValueError(f"Unsupported language: {self.typed_config.language}") - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=60) as client: response = await client.post( "http://sandbox:8194/v1/sandbox/run", headers={ diff --git a/api/app/core/workflow/nodes/cycle_graph/config.py b/api/app/core/workflow/nodes/cycle_graph/config.py index 52aca1d9..75358c47 100644 --- a/api/app/core/workflow/nodes/cycle_graph/config.py +++ b/api/app/core/workflow/nodes/cycle_graph/config.py @@ -51,7 +51,7 @@ class ConditionDetail(BaseModel): ) right: Any = Field( - ..., + default=None, description="Right-hand operand of the comparison expression" ) diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index d3ada1ec..84901bad 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -158,7 +158,7 @@ class LoopRuntime: self.variable_pool.variables["conv"].update( self.child_variable_pool.variables["conv"] ) - loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False) + loop_vars = self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) loopstate["node_outputs"][self.node_id] = loop_vars def evaluate_conditional(self) -> bool: @@ -261,4 +261,4 @@ class LoopRuntime: idx += 1 logger.info(f"loop node {self.node_id}: execution completed") - return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state} + return self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) | {"__child_state": child_state} diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 894898f0..638e4b2d 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -18,7 +18,7 @@ class ConditionDetail(BaseModel): ) right: Any = Field( - ..., + default=None, description="Value to compare with" ) diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 7e98efab..5d2bdf9a 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -31,13 +31,13 @@ class IfElseNode(BaseNode): expressions.append({ "left": self.get_variable(expression.left, variable_pool, strict=False), "right": expression.right - if expression.input_type == ValueInputType.CONSTANT + if expression.input_type == ValueInputType.CONSTANT or expression.right is None else self.get_variable(expression.right, variable_pool, strict=False), - "operator": expression.operator, + "operator": str(expression.operator), }) result.append({ "expressions": expressions, - "logical_operator": case.logical_operator, + "logical_operator": str(case.logical_operator), }) return { "cases": result diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 14f789a9..d3e9efd9 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -5,7 +5,7 @@ from typing import Any from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.models import RedBearRerank, RedBearModelConfig -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode @@ -24,6 +24,7 @@ class KnowledgeRetrievalNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config: KnowledgeRetrievalNodeConfig | None = None + self.vector_service: ElasticSearchVector | None = None def _output_types(self) -> dict[str, VariableType]: return { @@ -163,6 +164,50 @@ class KnowledgeRetrievalNode(BaseNode): ) return reranker + def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config): + if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER: + children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id) + for child in children: + if not (child and child.chunk_num > 0 and child.status == 1): + continue + kb_config.kb_id = child.id + self.knowledge_retrieval(db, query, rs, child, kb_config) + return + self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + indices = f"Vector_index_{kb_config.kb_id}_Node".lower() + match kb_config.retrieve_type: + case RetrieveType.PARTICIPLE: + rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.similarity_threshold)) + case RetrieveType.SEMANTIC: + rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.vector_similarity_weight)) + case RetrieveType.HYBRID: + rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.vector_similarity_weight) + rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.similarity_threshold) + + # Deduplicate hybrid retrieval results + unique_rs = self._deduplicate_docs(rs1, rs2) + if not unique_rs: + return + if self.typed_config.reranker_id: + self.vector_service.reranker = self.get_reranker_model() + rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) + else: + rs.extend(sorted( + unique_rs, + key=lambda d: d.metadata.get("score", 0), + reverse=True + )[:kb_config.top_k]) + case _: + raise RuntimeError("Unknown retrieval type") + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the knowledge retrieval workflow node. @@ -191,56 +236,19 @@ class KnowledgeRetrievalNode(BaseNode): query = self._render_template(self.typed_config.query, variable_pool) with get_db_read() as db: knowledge_bases = self.typed_config.knowledge_bases - existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases]) - - if not existing_ids: - raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") rs = [] for kb_config in knowledge_bases: db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) if not db_knowledge: raise RuntimeError("The knowledge base does not exist or access is denied.") + self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config) - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - indices = f"Vector_index_{kb_config.kb_id}_Node".lower() - match kb_config.retrieve_type: - case RetrieveType.PARTICIPLE: - rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold)) - case RetrieveType.SEMANTIC: - rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight)) - case RetrieveType.HYBRID: - rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold) - - # Deduplicate hy brid retrieval results - unique_rs = self._deduplicate_docs(rs1, rs2) - if not unique_rs: - continue - if self.typed_config.reranker_id: - vector_service.reranker = self.get_reranker_model() - rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) - else: - rs.extend(sorted( - unique_rs, - key=lambda d: d.metadata.get("score", 0), - reverse=True - )[:kb_config.top_k]) - case _: - raise RuntimeError("Unknown retrieval type") if not rs: return [] if self.typed_config.reranker_id: - vector_service.reranker = self.get_reranker_model() - final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) + self.vector_service.reranker = self.get_reranker_model() + final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) else: final_rs = sorted( rs, diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index be33d35a..14fc9d9f 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -250,6 +250,8 @@ class ConditionBase(ABC): self.type_limit = getattr(self, "type_limit", None) def resolve_right_literal_value(self): + if self.right_selector is None: + return None if self.input_type == ValueInputType.VARIABLE: pattern = r"\{\{\s*(.*?)\s*\}\}" right_expression = re.sub(pattern, r"\1", self.right_selector).strip() diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 3b6e9036..fe4aea19 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -170,7 +170,7 @@ class WorkflowValidator: # 仅在发布时验证所有节点可达 # 6. 验证所有节点可达(从 start 节点出发) if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 - reachable = WorkflowValidator._get_reachable_nodes( + reachable = WorkflowValidator.get_reachable_nodes( start_nodes[0]["id"], edges ) @@ -194,7 +194,7 @@ class WorkflowValidator: return len(errors) == 0, errors @staticmethod - def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]: + def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]: """获取从 start 节点可达的所有节点 Args: diff --git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py index aea40cf6..f5d8ff8f 100644 --- a/api/app/core/workflow/variable/base_variable.py +++ b/api/app/core/workflow/variable/base_variable.py @@ -2,7 +2,7 @@ from enum import StrEnum from abc import abstractmethod, ABC from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from app.schemas import FileType @@ -41,10 +41,10 @@ class VariableType(StrEnum): """ if isinstance(var, str): return cls.STRING - elif isinstance(var, (int, float)): - return cls.NUMBER elif isinstance(var, bool): return cls.BOOLEAN + elif isinstance(var, (int, float)): + return cls.NUMBER elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')): return cls.FILE elif isinstance(var, dict): @@ -116,7 +116,7 @@ class FileObject(BaseModel): content_cache: dict = Field(default_factory=dict) is_file: bool - _byte_content: bytes | None = None + _byte_content: bytes | None = PrivateAttr(default=None) def get_content(self): return self._byte_content diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 63437fd9..5e8e3f1e 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable) class StringVariable(BaseVariable): + value: str type = 'str' def valid_value(self, value) -> str: @@ -22,6 +23,7 @@ class StringVariable(BaseVariable): class NumberVariable(BaseVariable): + value: int | float type = 'number' def valid_value(self, value) -> int | float: @@ -34,6 +36,7 @@ class NumberVariable(BaseVariable): class BooleanVariable(BaseVariable): + value: bool type = 'boolean' def valid_value(self, value) -> bool: @@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable): class DictVariable(BaseVariable): + value: dict type = 'object' def valid_value(self, value) -> dict: @@ -58,6 +62,7 @@ class DictVariable(BaseVariable): class FileVariable(BaseVariable): + value: FileObject type = 'file' def valid_value(self, value) -> FileObject: @@ -102,6 +107,7 @@ class FileVariable(BaseVariable): class ArrayVariable(BaseVariable, Generic[T]): + value: list[T] type = 'array' def __init__(self, child_type: Type[T], value: list[Any]): @@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]): class NestedArrayVariable(BaseVariable): + value: list[ArrayVariable] type = 'array_nest' def valid_value(self, value: list[T]) -> list[T]: @@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable): category=RuntimeWarning ) class AnyVariable(BaseVariable): + value: Any type = 'any' def valid_value(self, value: Any) -> Any: diff --git a/api/app/db.py b/api/app/db.py index 80ab2756..32261c46 100644 --- a/api/app/db.py +++ b/api/app/db.py @@ -65,6 +65,7 @@ def get_db_read() -> Generator[Session, None, None]: yield db finally: db.rollback() # 只读任务无需 commit + db.close() def get_pool_status(): diff --git a/api/app/main.py b/api/app/main.py index c6256e3c..f4c23ca8 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -506,10 +506,13 @@ async def http_exception_handler(request: Request, exc: HTTPException): 404: "errors.common.not_found", 405: "errors.common.method_not_allowed", 409: "errors.common.conflict", + 413: "errors.common.payload_too_large", 422: "errors.common.validation_failed", 429: "errors.common.too_many_requests", 500: "errors.common.internal_error", + 502: "errors.common.bad_gateway", 503: "errors.common.service_unavailable", + 504: "errors.common.gateway_timeout", } # 如果有对应的翻译键,使用翻译 @@ -534,7 +537,7 @@ async def http_exception_handler(request: Request, exc: HTTPException): return JSONResponse( status_code=exc.status_code, - content=fail(code=exc.status_code, msg=translated_message, error=translated_message) + content=fail(code=exc.status_code, msg=translated_message, error=exc.detail) ) diff --git a/api/app/repositories/conversation_repository.py b/api/app/repositories/conversation_repository.py index eb5d3c61..90f2d6ec 100644 --- a/api/app/repositories/conversation_repository.py +++ b/api/app/repositories/conversation_repository.py @@ -90,27 +90,27 @@ class ConversationRepository: self, user_id: uuid.UUID, workspace_id: uuid.UUID = None, - limit: int = 10, - is_activate: bool = True - ) -> list[Conversation]: + is_activate: bool = True, + page: int = 1, + page_size: int = 20 + ) -> tuple[list[Conversation], int]: """ - Retrieve recent conversations for a specific user. + Retrieve recent conversations for a specific user with pagination. This method queries conversations associated with the given user ID, optionally scoped to a specific workspace. Results are ordered by the - most recently updated conversations and limited to a fixed number. + most recently updated conversations. Args: user_id (uuid.UUID): Unique identifier of the user. workspace_id (uuid.UUID, optional): Workspace scope for the query. If provided, only conversations under this workspace will be returned. - limit (int): Maximum number of conversations to return. - Defaults to 10. - is_activate (bool): Convsersation State limit + is_activate (bool): Conversation State limit. + page (int): Page number (1-based). Defaults to 1. + page_size (int): Number of items per page. Defaults to 20. Returns: - list[Conversation]: A list of conversation entities ordered by - last updated time (descending). + tuple[list[Conversation], int]: A list of conversation entities and total count. """ logger.info(f"Fetching conversation by user_id: {user_id}") @@ -122,18 +122,25 @@ class ConversationRepository: if workspace_id: stmt = stmt.where(Conversation.workspace_id == workspace_id) - stmt = stmt.order_by(desc(Conversation.updated_at)) - stmt = stmt.limit(limit) + # Calculate total count + total = int(self.db.execute( + select(func.count()).select_from(stmt.subquery()) + ).scalar_one()) - convsersations = list(self.db.scalars(stmt).all()) + # Apply ordering and pagination + stmt = stmt.order_by(desc(Conversation.updated_at)) + stmt = stmt.offset((page - 1) * page_size).limit(page_size) + + conversations = list(self.db.scalars(stmt).all()) logger.info( "Conversation fetched successfully", extra={ "user_id": str(user_id), "workspace_id": str(workspace_id), + "total": total, } ) - return convsersations + return conversations, total def list_conversations( self, diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index f9c4bd92..7273340e 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -17,12 +17,17 @@ from app.repositories.neo4j.cypher_queries import ( GET_ALL_ENTITY_IDS_FOR_USER, GET_ENTITIES_PAGE, GET_COMMUNITY_MEMBERS, + GET_COMMUNITY_RELATIONSHIPS, GET_ALL_COMMUNITY_MEMBERS_BATCH, GET_ALL_ENTITY_NEIGHBORS_BATCH, GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS, CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_METADATA, + GET_INCOMPLETE_COMMUNITIES, + GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING, + CHECK_COMMUNITY_IS_COMPLETE, + CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING, BATCH_UPDATE_COMMUNITY_METADATA, ) @@ -177,7 +182,7 @@ class CommunityRepository: async def get_community_members( self, community_id: str, end_user_id: str ) -> List[Dict]: - """查询社区成员列表。""" + """查询社区成员列表(含 example 字段)。""" try: return await self.connector.execute_query( GET_COMMUNITY_MEMBERS, @@ -188,6 +193,20 @@ class CommunityRepository: logger.error(f"get_community_members failed: {e}") return [] + async def get_community_relationships( + self, community_id: str, end_user_id: str + ) -> List[Dict]: + """查询社区内实体间的关系三元组(subject, predicate, object)。""" + try: + return await self.connector.execute_query( + GET_COMMUNITY_RELATIONSHIPS, + community_id=community_id, + end_user_id=end_user_id, + ) + except Exception as e: + logger.error(f"get_community_relationships failed: {e}") + return [] + async def get_all_community_members_batch( self, community_ids: List[str], end_user_id: str ) -> Dict[str, List[Dict]]: @@ -234,6 +253,31 @@ class CommunityRepository: logger.error(f"refresh_member_count failed: {e}") return 0 + async def get_incomplete_communities(self, end_user_id: str, check_embedding: bool = False) -> List[str]: + """查询该用户下属性不完整的 Community 节点 ID 列表。 + + Args: + end_user_id: 用户 ID + check_embedding: 为 True 时额外检查 summary_embedding 是否缺失(仅当用户有 embedding 模型配置时传 True) + """ + try: + query = GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING if check_embedding else GET_INCOMPLETE_COMMUNITIES + result = await self.connector.execute_query(query, end_user_id=end_user_id) + return [row["community_id"] for row in result] + except Exception as e: + logger.error(f"get_incomplete_communities failed: {e}") + return [] + + async def is_community_complete(self, community_id: str, end_user_id: str, check_embedding: bool = False) -> bool: + """检查单个社区节点的属性是否完整。""" + try: + query = CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING if check_embedding else CHECK_COMMUNITY_IS_COMPLETE + result = await self.connector.execute_query(query, community_id=community_id, end_user_id=end_user_id) + return result[0]["is_complete"] if result else False + except Exception as e: + logger.error(f"is_community_complete failed: {e}") + return False + async def update_community_metadata( self, community_id: str, @@ -243,7 +287,7 @@ class CommunityRepository: core_entities: List[str], summary_embedding: Optional[List[float]] = None, ) -> bool: - """更新社区的名称、摘要、核心实体列表和摘要向量。""" + """更新社区的名称、摘要、核心实体列表及 summary_embedding。""" try: result = await self.connector.execute_query( UPDATE_COMMUNITY_METADATA, diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 7b027ca9..0cdaeb59 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1137,10 +1137,20 @@ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->( RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, e.importance_score AS importance_score, e.activation_value AS activation_value, e.name_embedding AS name_embedding, - e.aliases AS aliases, e.description AS description + e.aliases AS aliases, e.description AS description, + e.example AS example ORDER BY coalesce(e.activation_value, 0) DESC """ +GET_COMMUNITY_RELATIONSHIPS = """ +MATCH (e1:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) +MATCH (e2:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c) +MATCH (e1)-[r:EXTRACTED_RELATIONSHIP]->(e2) +RETURN e1.name AS subject, r.predicate AS predicate, e2.name AS object +ORDER BY e1.name, r.predicate, e2.name +LIMIT 20 +""" + GET_ALL_COMMUNITY_MEMBERS_BATCH = """ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community) RETURN c.community_id AS community_id, @@ -1316,3 +1326,38 @@ RETURN s.statement AS statement, ORDER BY COALESCE(s.activation_value, 0) DESC LIMIT $limit """ + +CHECK_COMMUNITY_IS_COMPLETE = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +RETURN ( + c.name IS NOT NULL AND c.name <> '' AND + c.summary IS NOT NULL AND c.summary <> '' AND + c.core_entities IS NOT NULL +) AS is_complete +""" + +CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +RETURN ( + c.name IS NOT NULL AND c.name <> '' AND + c.summary IS NOT NULL AND c.summary <> '' AND + c.core_entities IS NOT NULL AND + c.summary_embedding IS NOT NULL +) AS is_complete +""" + +GET_INCOMPLETE_COMMUNITIES = """ +MATCH (c:Community {end_user_id: $end_user_id}) +WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL + OR c.name = '' OR c.summary = '' +RETURN c.community_id AS community_id +""" + +GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """ +MATCH (c:Community {end_user_id: $end_user_id}) +WHERE c.name IS NULL OR c.name = '' + OR c.summary IS NULL OR c.summary = '' + OR c.core_entities IS NULL + OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)') +RETURN c.community_id AS community_id +""" diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 29e337f1..34497d5b 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -1,4 +1,5 @@ import asyncio +import os from typing import List, Optional # 使用新的仓储层 @@ -304,7 +305,6 @@ async def save_dialog_and_statements_to_neo4j( def schedule_clustering_after_write( entity_nodes: List, - config_id: Optional[str] = None, llm_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None, ) -> None: @@ -325,13 +325,12 @@ def schedule_clustering_after_write( end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) + asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) async def _trigger_clustering( new_entity_ids: List[str], end_user_id: str, - config_id: Optional[str] = None, llm_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None, ) -> None: @@ -343,7 +342,7 @@ async def _trigger_clustering( from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") connector = Neo4jConnector() - engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id) + engine = LabelPropagationEngine(connector, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}") except Exception as e: diff --git a/api/app/repositories/workflow_repository.py b/api/app/repositories/workflow_repository.py index b22673e6..4e24faa0 100644 --- a/api/app/repositories/workflow_repository.py +++ b/api/app/repositories/workflow_repository.py @@ -43,6 +43,7 @@ class WorkflowConfigRepository: edges: list[dict[str, Any]], variables: list[dict[str, Any]] | None = None, execution_config: dict[str, Any] | None = None, + features: dict[str, Any] | None = None, triggers: list[dict[str, Any]] | None = None ) -> WorkflowConfig: """创建或更新工作流配置 @@ -53,6 +54,7 @@ class WorkflowConfigRepository: edges: 边列表 variables: 变量列表 execution_config: 执行配置 + features: 功能特性 triggers: 触发器列表 Returns: @@ -82,6 +84,7 @@ class WorkflowConfigRepository: edges=edges, variables=variables or [], execution_config=execution_config or {}, + features=features or {}, triggers=triggers or [] ) self.db.add(config) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 5238b978..1582d862 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -149,18 +149,26 @@ class FileUploadConfig(BaseModel): ) # 通用文件:PDF/DOCX/XLSX/TXT/CSV/JSON,最大 100MB document_enabled: bool = Field(default=False) - document_max_size_mb: int = Field(default=100) + document_max_size_mb: int = Field(default=50) document_allowed_extensions: List[str] = Field( - default=["pdf", "docx", "xlsx", "txt", "csv", "json", "md"] + default=["pdf", "docx", "doc", "xlsx", "xls", "txt", "csv", "json", "md"] ) # 视频文件:MP4/MOV/AVI/WebM,最大 500MB video_enabled: bool = Field(default=False) - video_max_size_mb: int = Field(default=500) + video_max_size_mb: int = Field(default=50) video_allowed_extensions: List[str] = Field( - default=["mp4", "mov"] + default=["mp4"] ) # 最大文件数量 - max_file_count: int = Field(default=5, ge=1, le=20) + max_file_count: int = Field(default=5, ge=1) + + @field_validator("max_file_count") + @classmethod + def validate_max_file_count(cls, v: int) -> int: + from app.core.config import settings + if v > settings.MAX_FILE_COUNT: + raise ValueError(f"max_file_count 不能超过 {settings.MAX_FILE_COUNT}") + return v class OpeningStatementConfig(BaseModel): diff --git a/api/app/schemas/memory_api_schema.py b/api/app/schemas/memory_api_schema.py index 98d257c1..84a34e8a 100644 --- a/api/app/schemas/memory_api_schema.py +++ b/api/app/schemas/memory_api_schema.py @@ -21,7 +21,7 @@ class MemoryWriteRequest(BaseModel): """ end_user_id: str = Field(..., description="End user ID (required)") message: str = Field(..., description="Message content to store") - config_id: Optional[str] = Field(None, description="Memory configuration ID") + config_id: str = Field(..., description="Memory configuration ID (required)") storage_type: str = Field("neo4j", description="Storage type: neo4j or rag") user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID") @@ -68,7 +68,7 @@ class MemoryReadRequest(BaseModel): "0", description="Search mode: 0=verify, 1=direct, 2=context" ) - config_id: Optional[str] = Field(None, description="Memory configuration ID") + config_id: str = Field(..., description="Memory configuration ID (required)") storage_type: str = Field("neo4j", description="Storage type: neo4j or rag") user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID") @@ -132,3 +132,79 @@ class MemoryReadResponse(BaseModel): description="Intermediate retrieval outputs" ) end_user_id: str = Field(..., description="End user ID") + + +class CreateEndUserRequest(BaseModel): + """Request schema for creating an end user. + + Attributes: + workspace_id: Workspace ID (required) + other_id: External user identifier (required) + other_name: Display name for the end user + """ + workspace_id: str = Field(..., description="Workspace ID (required)") + other_id: str = Field(..., description="External user identifier (required)") + other_name: Optional[str] = Field("", description="Display name") + + @field_validator("workspace_id") + @classmethod + def validate_workspace_id(cls, v: str) -> str: + """Validate that workspace_id is not empty.""" + if not v or not v.strip(): + raise ValueError("workspace_id is required and cannot be empty") + return v.strip() + + @field_validator("other_id") + @classmethod + def validate_other_id(cls, v: str) -> str: + """Validate that other_id is not empty.""" + if not v or not v.strip(): + raise ValueError("other_id is required and cannot be empty") + return v.strip() + + +class CreateEndUserResponse(BaseModel): + """Response schema for end user creation. + + Attributes: + id: Created end user UUID + other_id: External user identifier + other_name: Display name + workspace_id: Workspace the user belongs to + """ + id: str = Field(..., description="End user UUID") + other_id: str = Field(..., description="External user identifier") + other_name: str = Field("", description="Display name") + workspace_id: str = Field(..., description="Workspace ID") + + +class MemoryConfigItem(BaseModel): + """Schema for a single memory config in the list response. + + Attributes: + config_id: Configuration UUID + config_name: Configuration name + config_desc: Configuration description + is_default: Whether this is the workspace default config + scene_name: Associated ontology scene name + created_at: Creation timestamp + updated_at: Last update timestamp + """ + config_id: str = Field(..., description="Configuration ID") + config_name: str = Field(..., description="Configuration name") + config_desc: Optional[str] = Field(None, description="Configuration description") + is_default: bool = Field(False, description="Whether this is the workspace default") + scene_name: Optional[str] = Field(None, description="Associated ontology scene name") + created_at: Optional[str] = Field(None, description="Creation timestamp") + updated_at: Optional[str] = Field(None, description="Last update timestamp") + + +class ListConfigsResponse(BaseModel): + """Response schema for listing memory configs. + + Attributes: + configs: List of memory config items + total: Total number of configs + """ + configs: List[MemoryConfigItem] = Field(default_factory=list, description="List of configs") + total: int = Field(0, description="Total number of configs") diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 604514b4..6fcf680b 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -118,28 +118,54 @@ class AppChatService: ) + model_info = ModelInfo( + model_name=api_key_obj.model_name, + provider=api_key_obj.provider, + api_key=api_key_obj.api_key, + api_base=api_key_obj.api_base, + capability=api_key_obj.capability, + is_omni=api_key_obj.is_omni, + model_type=ModelType.LLM + ) + # 加载历史消息 messages = self.conversation_service.get_messages( conversation_id=conversation_id, limit=10 ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] + history = [] + for msg in messages: + content = [{"type": "text", "text": msg.content}] + + # 处理 meta_data 中的 files + if msg.meta_data and msg.meta_data.get("files"): + files = msg.meta_data.get("files", []) + # 使用 MultimodalService 处理文件 + multimodal_service = MultimodalService(self.db, api_config=model_info) + + # 将 files 转换为 FileInput 格式 + file_inputs = [] + for file in files: + from app.schemas.app_schema import FileInput, TransferMethod + file_input = FileInput( + type=file.get("type"), + transfer_method=TransferMethod.REMOTE_URL, + url=file.get("url") + ) + file_inputs.append(file_input) + + history_processed_files = await multimodal_service.history_process_files(files=file_inputs) + + content.extend(history_processed_files) + + history.append({ + "role": msg.role, + "content": content + }) # 处理多模态文件 processed_files = None if files: - model_info = ModelInfo( - model_name=api_key_obj.model_name, - provider=api_key_obj.provider, - api_key=api_key_obj.api_key, - api_base=api_key_obj.api_base, - capability=api_key_obj.capability, - is_omni=api_key_obj.is_omni, - model_type=ModelType.LLM - ) multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(user_id, files) logger.info(f"处理了 {len(processed_files)} 个文件") @@ -313,31 +339,54 @@ class AppChatService: streaming=True ) + model_info = ModelInfo( + model_name=api_key_obj.model_name, + provider=api_key_obj.provider, + api_key=api_key_obj.api_key, + api_base=api_key_obj.api_base, + capability=api_key_obj.capability, + is_omni=api_key_obj.is_omni, + model_type=ModelType.LLM + ) + # 加载历史消息 + messages = self.conversation_service.get_messages( + conversation_id=conversation_id, + limit=10 + ) history = [] - memory_config = {"enabled": True, 'max_history': 10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation_id, - limit=memory_config.get("max_history", 10) - ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] + for msg in messages: + content = [{"type": "text", "text": msg.content}] + + # 处理 meta_data 中的 files + if msg.meta_data and msg.meta_data.get("files"): + history_files = msg.meta_data.get("files", []) + # 使用 MultimodalService 处理文件 + multimodal_service = MultimodalService(self.db, api_config=model_info) + + # 将 files 转换为 FileInput 格式 + file_inputs = [] + for file in history_files: + from app.schemas.app_schema import FileInput, TransferMethod + file_input = FileInput( + type=file.get("type"), + transfer_method=TransferMethod.REMOTE_URL, + url=file.get("url") + ) + file_inputs.append(file_input) + + history_processed_files = await multimodal_service.history_process_files(files=file_inputs) + + content.extend(history_processed_files) + + history.append({ + "role": msg.role, + "content": content + }) # 处理多模态文件 processed_files = None if files: - model_info = ModelInfo( - model_name=api_key_obj.model_name, - provider=api_key_obj.provider, - api_key=api_key_obj.api_key, - api_base=api_key_obj.api_base, - capability=api_key_obj.capability, - is_omni=api_key_obj.is_omni, - model_type=ModelType.LLM - ) multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(user_id, files) logger.info(f"处理了 {len(processed_files)} 个文件") @@ -347,8 +396,14 @@ class AppChatService: total_tokens = 0 text_queue: asyncio.Queue = asyncio.Queue() + api_key_config = { + "model_name": api_key_obj.model_name, + "api_key": api_key_obj.api_key, + "api_base": api_key_obj.api_base, + "provider": api_key_obj.provider, + } stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming( - features_config, api_key_obj, + features_config, api_key_config, text_queue=text_queue, tenant_id=tenant_id, workspace_id=workspace_id ) diff --git a/api/app/services/app_dsl_service.py b/api/app/services/app_dsl_service.py index a10aa70a..8c198be4 100644 --- a/api/app/services/app_dsl_service.py +++ b/api/app/services/app_dsl_service.py @@ -16,6 +16,7 @@ from app.models.app_release_model import AppRelease from app.models.knowledge_model import Knowledge from app.models.models_model import ModelConfig from app.models.tool_model import ToolConfig as ToolConfigModel +from app.models.skill_model import Skill from app.models.workflow_model import WorkflowConfig from app.services.workflow_service import WorkflowService from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter @@ -84,7 +85,9 @@ class AppDslService: if "knowledge_retrieval" in cfg: enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"]) if "tools" in cfg: - enriched["tools"] = self._enrich_tools(cfg["tools"]) + enriched["tools"] = self._enrich_tools(cfg.get("tools")) + if "skills" in cfg: + enriched["skills"] = self._enrich_skills(cfg.get("skills")) return enriched if app_type == AppType.MULTI_AGENT: enriched = {**cfg} @@ -108,6 +111,7 @@ class AppDslService: "variables": config.variables if config else [], "edges": config.edges if config else [], "nodes": config.nodes if config else [], + "features": config.features if config else {}, "execution_config": config.execution_config if config else {}, "triggers": config.triggers if config else [], } if config else {} @@ -123,7 +127,8 @@ class AppDslService: "memory": config.memory if config else None, "variables": config.variables if config else [], "tools": self._enrich_tools(config.tools) if config else [], - "skills": config.skills if config else {}, + "skills": self._enrich_skills(config.skills) if config else {}, + "features": config.features if config else {} } if config else {} dsl = {**meta, "app": app_meta, "agent_config": config_data} @@ -185,6 +190,22 @@ class AppDslService: def _enrich_tools(self, tools: list) -> list: return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])] + def _skill_ref(self, skill_id) -> Optional[dict]: + if not skill_id: + return None + s = self.db.query(Skill).filter(Skill.id == skill_id).first() + return {"id": str(skill_id), "name": s.name} if s else {"id": str(skill_id)} + + def _enrich_skills(self, skills: Optional[dict]) -> Optional[dict]: + if not skills: + return skills + skill_ids = skills.get("skill_ids", []) + enriched_ids = [ + {"id": sid, "_ref": self._skill_ref(sid)} + for sid in (skill_ids or []) + ] + return {**skills, "skill_ids": enriched_ids} + def _agent_ref(self, agent_id) -> Optional[dict]: if not agent_id: return None @@ -249,7 +270,8 @@ class AppDslService: memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings), variables=cfg.get("variables", []), tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings), - skills=cfg.get("skills", {}), + skills=self._resolve_skills(cfg.get("skills", {}), tenant_id, warnings), + features=cfg.get("features", {}), is_active=True, created_at=now, updated_at=now, @@ -290,6 +312,7 @@ class AppDslService: edges=[e.model_dump() for e in result.edges], variables=[v.model_dump() for v in result.variables], execution_config=wf.get("execution_config", {}), + features=wf.get("features", {}), triggers=wf.get("triggers", []), validate=False, ) @@ -444,6 +467,46 @@ class AppDslService: return {**memory, "memory_config_id": None, "enabled": False} return memory + def _resolve_skills(self, skills: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> dict: + if not skills: + return skills or {} + resolved_ids = [] + for entry in (skills.get("skill_ids") or []): + # entry 可能是 {"id": "...", "_ref": {...}} 或直接是字符串 + if isinstance(entry, dict): + ref = entry.get("_ref") or ({"name": None, "id": entry.get("id")} if entry.get("id") else None) + skill_id = self._resolve_skill(ref, tenant_id, warnings) + else: + skill_id = self._resolve_skill({"id": str(entry)}, tenant_id, warnings) + if skill_id: + resolved_ids.append(str(skill_id)) + return {**{k: v for k, v in skills.items() if k != "skill_ids"}, "skill_ids": resolved_ids} + + def _resolve_skill(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]: + if not ref: + return None + # 先按 id 匹配 + if ref.get("id"): + try: + s = self.db.query(Skill).filter( + Skill.id == uuid.UUID(str(ref["id"])), + Skill.tenant_id == tenant_id + ).first() + if s: + return str(s.id) + except Exception: + pass + # 再按名称匹配 + if ref.get("name"): + s = self.db.query(Skill).filter( + Skill.name == ref["name"], + Skill.tenant_id == tenant_id + ).first() + if s: + return str(s.id) + warnings.append(f"未找到技能: {ref}") + return None + def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list: result = [] for t in (tools or []): diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 68d255f8..19aaac42 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -833,8 +833,6 @@ class AppService: # 跨工作空间时,获取目标工作空间的 tenant_id 用于判断模型配置是否可用 target_tenant_id = None - available_model_ids: set = set() - available_kb_ids: set = set() if is_cross_workspace: target_ws = self.db.get(Workspace, target_workspace_id) if not target_ws: @@ -849,28 +847,29 @@ class AppService: if source_config: if is_cross_workspace: - # Batch-collect and preload all referenced resources - model_ids, kb_ids = self._collect_resource_ids_from_config( - source_config.default_model_config_id, - source_config.knowledge_retrieval, - source_config.tools + # 跨工作空间:model/tools/skills 属于 tenant 级别直接保留, + # knowledge_bases 属于 workspace 级别需过滤,memory_config 需清空 + _, kb_ids = self._collect_resource_ids_from_config( + None, source_config.knowledge_retrieval ) - available_model_ids, available_kb_ids = self._preload_cross_workspace_resources( - target_tenant_id, target_workspace_id, model_ids, kb_ids - ) - new_model_config_id = self._is_model_available( - source_config.default_model_config_id, available_model_ids + _, available_kb_ids = self._preload_cross_workspace_resources( + target_tenant_id, target_workspace_id, set(), kb_ids ) + new_model_config_id = source_config.default_model_config_id new_knowledge_retrieval = self._clean_knowledge_retrieval( source_config.knowledge_retrieval, available_kb_ids ) - new_tools = self._clean_tools( - source_config.tools, available_kb_ids + new_tools = copy.deepcopy(source_config.tools) if source_config.tools else [] + new_memory = self._clean_memory_cross_workspace( + source_config.memory, target_workspace_id ) + new_skills = copy.deepcopy(source_config.skills) if source_config.skills else {} else: new_model_config_id = source_config.default_model_config_id new_knowledge_retrieval = copy.deepcopy(source_config.knowledge_retrieval) if source_config.knowledge_retrieval else None new_tools = copy.deepcopy(source_config.tools) if source_config.tools else [] + new_memory = copy.deepcopy(source_config.memory) if source_config.memory else None + new_skills = copy.deepcopy(source_config.skills) if source_config.skills else {} new_config = AgentConfig( id=uuid.uuid4(), @@ -879,9 +878,11 @@ class AppService: default_model_config_id=new_model_config_id, model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None, knowledge_retrieval=new_knowledge_retrieval, - memory=copy.deepcopy(source_config.memory) if source_config.memory else None, + memory=new_memory, variables=copy.deepcopy(source_config.variables) if source_config.variables else [], tools=new_tools, + skills=new_skills, + features=copy.deepcopy(source_config.features) if source_config.features else {}, is_active=True, created_at=now, updated_at=now, @@ -894,28 +895,14 @@ class AppService: ).first() if source_config: - if is_cross_workspace: - model_ids, kb_ids = self._collect_resource_ids_from_workflow_nodes( - source_config.nodes - ) - available_model_ids, available_kb_ids = self._preload_cross_workspace_resources( - target_tenant_id, target_workspace_id, model_ids, kb_ids - ) - new_nodes = self._clean_workflow_nodes_for_cross_workspace( - source_config.nodes or [], - available_model_ids, - available_kb_ids - ) - else: - new_nodes = copy.deepcopy(source_config.nodes) if source_config.nodes else [] - new_config = WorkflowConfig( id=uuid.uuid4(), app_id=new_app.id, - nodes=new_nodes, + nodes=copy.deepcopy(source_config.nodes) if source_config.nodes else [], edges=copy.deepcopy(source_config.edges) if source_config.edges else [], variables=copy.deepcopy(source_config.variables) if source_config.variables else [], execution_config=copy.deepcopy(source_config.execution_config) if source_config.execution_config else {}, + features=copy.deepcopy(source_config.features) if source_config.features else {}, triggers=copy.deepcopy(source_config.triggers) if source_config.triggers else [], is_active=True, created_at=now, @@ -929,24 +916,15 @@ class AppService: ).first() if source_config: - if is_cross_workspace: - model_ids = {source_config.default_model_config_id} if source_config.default_model_config_id else set() - available_model_ids, _ = self._preload_cross_workspace_resources( - target_tenant_id, target_workspace_id, model_ids, set() - ) - new_model_config_id = self._is_model_available( - source_config.default_model_config_id, available_model_ids - ) - else: - new_model_config_id = source_config.default_model_config_id - + # multi_agent 的 model_config_id/sub_agents/routing_rules 均属于 tenant 级别直接保留 + # 跨空间时 master_agent_id(AppRelease)属于源空间,需清空 new_config = MultiAgentConfig( id=uuid.uuid4(), app_id=new_app.id, master_agent_id=source_config.master_agent_id if not is_cross_workspace else None, master_agent_name=source_config.master_agent_name, - default_model_config_id=new_model_config_id, - model_parameters=source_config.model_parameters, + default_model_config_id=source_config.default_model_config_id, + model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None, orchestration_mode=source_config.orchestration_mode, sub_agents=copy.deepcopy(source_config.sub_agents) if source_config.sub_agents else [], routing_rules=copy.deepcopy(source_config.routing_rules) if source_config.routing_rules else None, @@ -1037,8 +1015,7 @@ class AppService: @staticmethod def _collect_resource_ids_from_config( model_config_id: Optional[uuid.UUID], - knowledge_retrieval: Optional[dict], - tools: Optional[list] + knowledge_retrieval: Optional[dict] ) -> tuple: """Extract all model config IDs and knowledge base IDs from an app config.""" model_ids: set = set() @@ -1048,62 +1025,12 @@ class AppService: model_ids.add(model_config_id) if knowledge_retrieval and isinstance(knowledge_retrieval, dict): - if "kb_ids" in knowledge_retrieval: - for kid in knowledge_retrieval.get("kb_ids", []): - if kid: - kb_ids.add(str(kid)) - if knowledge_retrieval.get("knowledge_id"): - kb_ids.add(str(knowledge_retrieval["knowledge_id"])) - - if tools: - for tool in tools: - if isinstance(tool, dict): - kid = tool.get("knowledge_id") or tool.get("kb_id") - if kid: - kb_ids.add(str(kid)) + if "knowledge_bases" in knowledge_retrieval: + for kid in knowledge_retrieval.get("knowledge_bases", []): + kb_ids.add(str(kid.get("kb_id"))) return model_ids, kb_ids - @staticmethod - def _collect_resource_ids_from_workflow_nodes(nodes: list) -> tuple: - """Extract all model config IDs and knowledge base IDs from workflow nodes.""" - model_ids: set = set() - kb_ids: set = set() - - for node in (nodes or []): - if not isinstance(node, dict): - continue - data = node.get("data", {}) - if not isinstance(data, dict): - continue - for key in ("model_config_id", "default_model_config_id"): - val = data.get(key) - if val: - try: - model_ids.add(uuid.UUID(str(val))) - except (ValueError, AttributeError): - pass - kr = data.get("knowledge_retrieval") - if isinstance(kr, dict): - for kid in kr.get("kb_ids", []): - if kid: - kb_ids.add(str(kid)) - if kr.get("knowledge_id"): - kb_ids.add(str(kr["knowledge_id"])) - if data.get("knowledge_id"): - kb_ids.add(str(data["knowledge_id"])) - for kid in data.get("kb_ids", []): - if kid: - kb_ids.add(str(kid)) - - return model_ids, kb_ids - - @staticmethod - def _is_model_available(model_config_id: Optional[uuid.UUID], available_model_ids: set) -> Optional[uuid.UUID]: - if not model_config_id: - return None - return model_config_id if model_config_id in available_model_ids else None - @staticmethod def _is_kb_available(kb_id: Optional[str], available_kb_ids: set) -> Optional[str]: if not kb_id: @@ -1124,95 +1051,53 @@ class AppService: cleaned = copy.deepcopy(knowledge_retrieval) - if "kb_ids" in cleaned and isinstance(cleaned["kb_ids"], list): - cleaned["kb_ids"] = [ - kid for kid in cleaned["kb_ids"] - if self._is_kb_available(kid, available_kb_ids) + if "knowledge_bases" in cleaned and isinstance(cleaned["knowledge_bases"], list): + cleaned["knowledge_bases"] = [ + kb for kb in cleaned["knowledge_bases"] + if self._is_kb_available(kb.get("kb_id"), available_kb_ids) ] - if "knowledge_id" in cleaned: - cleaned["knowledge_id"] = self._is_kb_available( - cleaned.get("knowledge_id"), available_kb_ids - ) - return cleaned - def _clean_tools( + def _clean_memory_cross_workspace( self, - tools: Optional[list], - available_kb_ids: set - ) -> list: - """Clean tools config, keeping built-in tools and tools with available KBs.""" - if not tools: - return [] + memory: Optional[dict], + target_workspace_id: uuid.UUID + ) -> Optional[dict]: + """Clear memory_config_id/memory_content if it doesn't belong to target workspace.""" + if not memory: + return None - cleaned = [] - for tool in tools: - if not isinstance(tool, dict): - cleaned.append(tool) - continue + from app.models.memory_config_model import MemoryConfig - tool_type = tool.get("type", "") - if tool_type in ("builtin", "built_in", "system"): - cleaned.append(copy.deepcopy(tool)) - continue + cleaned = copy.deepcopy(memory) + # 兼容旧字段 memory_content 和新字段 memory_config_id + mid = cleaned.get("memory_config_id") or cleaned.get("memory_content") + if mid: + try: + mid_uuid = uuid.UUID(str(mid)) + except (ValueError, AttributeError): + exists = self.db.query(MemoryConfig).filter( + MemoryConfig.config_id_old == int(mid), + MemoryConfig.workspace_id == target_workspace_id + ).first() + if not exists: + cleaned["memory_config_id"] = None + cleaned.pop("memory_content", None) + cleaned["enabled"] = False + return cleaned - kb_id = tool.get("knowledge_id") or tool.get("kb_id") - if kb_id: - if self._is_kb_available(kb_id, available_kb_ids): - cleaned.append(copy.deepcopy(tool)) - continue + exists = self.db.query( + self.db.query(MemoryConfig).filter( + MemoryConfig.config_id == mid_uuid, + MemoryConfig.workspace_id == target_workspace_id + ).exists() + ).scalar() + if not exists: + cleaned["memory_config_id"] = None + cleaned.pop("memory_content", None) + cleaned["enabled"] = False - cleaned.append(copy.deepcopy(tool)) - - return cleaned - - def _clean_workflow_nodes_for_cross_workspace( - self, - nodes: list, - available_model_ids: set, - available_kb_ids: set - ) -> list: - """Clean workflow nodes, using pre-loaded resource sets. Uses deepcopy to avoid mutating source.""" - if not nodes: - return [] - - cleaned = [] - for node in nodes: - if not isinstance(node, dict): - cleaned.append(node) - continue - - node_copy = copy.deepcopy(node) - data = node_copy.get("data") - if not isinstance(data, dict): - cleaned.append(node_copy) - continue - - for key in ("model_config_id", "default_model_config_id"): - if key in data and data[key]: - try: - mid = uuid.UUID(str(data[key])) - except (ValueError, AttributeError): - data[key] = None - continue - data[key] = str(mid) if mid in available_model_ids else None - - if "knowledge_retrieval" in data and data["knowledge_retrieval"]: - data["knowledge_retrieval"] = self._clean_knowledge_retrieval( - data["knowledge_retrieval"], available_kb_ids - ) - if "knowledge_id" in data: - data["knowledge_id"] = self._is_kb_available( - data.get("knowledge_id"), available_kb_ids - ) - if "kb_ids" in data and isinstance(data["kb_ids"], list): - data["kb_ids"] = [ - kid for kid in data["kb_ids"] - if self._is_kb_available(kid, available_kb_ids) - ] - - cleaned.append(node_copy) return cleaned def list_apps( diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index aff5f533..f8a01a40 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -21,6 +21,7 @@ from app.models.conversation_model import ConversationDetail from app.models.prompt_optimizer_model import RoleType from app.repositories.conversation_repository import ConversationRepository, MessageRepository from app.schemas.conversation_schema import ConversationOut +from app.schemas.model_schema import ModelInfo from app.services import workspace_service from app.services.model_service import ModelConfigService @@ -119,25 +120,27 @@ class ConversationService: def get_user_conversations( self, - user_id: uuid.UUID - ) -> list[Conversation]: + user_id: uuid.UUID, + page: int = 1, + page_size: int = 20 + ) -> tuple[list[Conversation], int]: """ - Retrieve recent conversations for a specific user - - This method delegates persistence logic to the repository layer and - applies service-level defaults (e.g. recent conversation limit). + Retrieve recent conversations for a specific user with pagination. Args: user_id (uuid.UUID): Unique identifier of the user. + page (int): Page number (1-based). Defaults to 1. + page_size (int): Number of items per page. Defaults to 20. Returns: - list[Conversation]: A list of recent conversation entities. + tuple[list[Conversation], int]: A list of recent conversation entities and total count. """ - conversations = self.conversation_repo.get_conversation_by_user_id( + conversations, total = self.conversation_repo.get_conversation_by_user_id( user_id, - limit=10 + page=page, + page_size=page_size ) - return conversations + return conversations, total def list_conversations( self, @@ -267,10 +270,11 @@ class ConversationService: return messages - def get_conversation_history( + async def get_conversation_history( self, conversation_id: uuid.UUID, - max_history: Optional[int] = None + max_history: Optional[int] = None, + api_config: Optional[ModelInfo] = None ) -> List[dict]: """ Retrieve historical conversation messages formatted as dictionaries. @@ -278,6 +282,7 @@ class ConversationService: Args: conversation_id (uuid.UUID): Conversation UUID. max_history (Optional[int]): Maximum number of messages to retrieve. + api_config (Optional[ModelInfo]): Model API configuration for multimodal processing. Returns: List[dict]: List of message dictionaries with keys 'role' and 'content'. @@ -288,13 +293,37 @@ class ConversationService: ) # 转换为字典格式 - history = [ - { + history = [] + for msg in messages: + content = [{"type": "text", "text": msg.content}] + + # 处理 meta_data 中的 files + if msg.meta_data and msg.meta_data.get("files"): + files = msg.meta_data.get("files", []) + if api_config: + # 使用 MultimodalService 处理文件 + from app.services.multimodal_service import MultimodalService + multimodal_service = MultimodalService(self.db, api_config=api_config) + + # 将 files 转换为 FileInput 格式 + file_inputs = [] + for file in files: + from app.schemas.app_schema import FileInput, TransferMethod + file_input = FileInput( + type=file.get("type"), + transfer_method=TransferMethod.REMOTE_URL, + url=file.get("url") + ) + file_inputs.append(file_input) + + processed_files = await multimodal_service.history_process_files(files=file_inputs) + + content.extend(processed_files) + + history.append({ "role": msg.role, - "content": msg.content - } - for msg in messages - ] + "content": content + }) return history @@ -522,9 +551,18 @@ class ConversationService: type=ModelType(model_type) ) - conversation_messages = self.get_conversation_history( + conversation_messages = await self.get_conversation_history( conversation_id=conversation_id, - max_history=20 + max_history=20, + api_config=ModelInfo( + model_name=model_name, + provider=provider, + api_key=api_key, + api_base=api_base, + capability=api_config.capability, + is_omni=api_config.is_omni, + model_type=model_type + ) ) if len(conversation_messages) == 0: return ConversationOut( diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index ba41d323..5989f0f8 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -579,9 +579,20 @@ class AgentRunService: user_id=user_id ) + model_info = ModelInfo( + model_name=api_key_config["model_name"], + provider=api_key_config["provider"], + api_key=api_key_config["api_key"], + api_base=api_key_config["api_base"], + capability=api_key_config["capability"], + is_omni=api_key_config["is_omni"], + model_type=model_config.type + ) + # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, + api_config=model_info, max_history=10 ) @@ -589,15 +600,6 @@ class AgentRunService: processed_files = None if files: # 获取 provider 信息 - model_info = ModelInfo( - model_name=api_key_config["model_name"], - provider=api_key_config["provider"], - api_key=api_key_config["api_key"], - api_base=api_key_config["api_base"], - capability=api_key_config["capability"], - is_omni=api_key_config["is_omni"], - model_type=ModelType.LLM - ) provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(user_id, files) @@ -815,9 +817,20 @@ class AgentRunService: sub_agent=sub_agent ) + model_info = ModelInfo( + model_name=api_key_config["model_name"], + provider=api_key_config["provider"], + api_key=api_key_config["api_key"], + api_base=api_key_config["api_base"], + capability=api_key_config["capability"], + is_omni=api_key_config["is_omni"], + model_type=model_config.type + ) + # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, + api_config=model_info, max_history=memory_config.get("max_history", 10) ) @@ -825,15 +838,6 @@ class AgentRunService: processed_files = None if files: # 获取 provider 信息 - model_info = ModelInfo( - model_name=api_key_config["model_name"], - provider=api_key_config["provider"], - api_key=api_key_config["api_key"], - api_base=api_key_config["api_base"], - capability=api_key_config["capability"], - is_omni=api_key_config["is_omni"], - model_type=ModelType.LLM - ) provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(user_id, files) @@ -1115,6 +1119,7 @@ class AgentRunService: async def _load_conversation_history( self, conversation_id: str, + api_config: ModelInfo | None = None, max_history: int = 10 ) -> List[Dict[str, str]]: """加载会话历史消息 @@ -1129,9 +1134,11 @@ class AgentRunService: try: conversation_service = ConversationService(self.db) - history = conversation_service.get_conversation_history( + # 获取 API 配置用于多模态处理 + history = await conversation_service.get_conversation_history( conversation_id=uuid.UUID(conversation_id), - max_history=max_history + max_history=max_history, + api_config=api_config ) logger.debug( diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 1e1d9e45..af9a04e2 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -1179,7 +1179,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An app = db.query(App).filter(App.id == app_id).first() if not app: logger.warning(f"App not found: {app_id}") - raise ValueError(f"应用不存在: {app_id}") + # raise ValueError(f"应用不存在: {app_id}") # TODO: temp fix for draft run # if not app.current_release_id: # logger.warning(f"No current release for app: {app_id}") @@ -1252,17 +1252,15 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An memory_config_service = MemoryConfigService(db) memory_config = memory_config_service.get_config_with_fallback( memory_config_id=memory_config_id_to_use, - workspace_id=app.workspace_id + workspace_id=end_user.workspace_id ) memory_config_id = str(memory_config.config_id) if memory_config else None result = { "end_user_id": str(end_user_id), - "app_id": str(app_id), - "release_id": str(app.current_release_id) if app.current_release_id else None, "memory_config_id": memory_config_id, - "workspace_id": str(app.workspace_id) + "workspace_id": str(end_user.workspace_id) } logger.info( diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index f86fbed8..01bc6267 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -84,43 +84,65 @@ class MemoryAPIService: if not app: logger.warning(f"App not found for end_user: {end_user_id}") - raise ResourceNotFoundException( - resource_type="App", - resource_id=str(end_user.app_id) - ) - - if app.workspace_id != workspace_id: - logger.warning( - f"End user {end_user_id} belongs to workspace {app.workspace_id}, " - f"not authorized workspace {workspace_id}" - ) - raise BusinessException( - message="End user does not belong to authorized workspace", - code=BizCode.FORBIDDEN - ) + # raise ResourceNotFoundException( + # resource_type="App", + # resource_id=str(end_user.app_id) + # ) + # temporally allow any workspace to access + # if end_user.workspace_id != workspace_id: + # print(f"[DEBUG] end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}") + # logger.warning( + # f"End user {end_user_id} belongs to workspace {end_user.workspace_id}, " + # f"not authorized workspace {workspace_id}" + # ) + # raise BusinessException( + # message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}", + # code=BizCode.FORBIDDEN + # ) logger.info(f"End user {end_user_id} validated successfully") return end_user - + + def _update_end_user_config(self, end_user_id: str, config_id: str) -> None: + """Update the end user's memory_config_id. + + Silently updates the config association. Logs warnings on failure + but does not raise, so it won't block the main read/write operation. + + Args: + end_user_id: End user identifier + config_id: Memory configuration ID to assign + """ + try: + config_uuid = uuid.UUID(config_id) + from app.repositories.end_user_repository import EndUserRepository + end_user_repo = EndUserRepository(self.db) + end_user_repo.update_memory_config_id( + end_user_id=uuid.UUID(end_user_id), + memory_config_id=config_uuid, + ) + except Exception as e: + logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}") + async def write_memory( self, workspace_id: uuid.UUID, end_user_id: str, message: str, - config_id: Optional[str] = None, + config_id: str, storage_type: str = "neo4j", user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Write memory with validation. - Validates end_user exists and belongs to workspace, then delegates - to MemoryAgentService.write_memory. + Validates end_user exists and belongs to workspace, updates the end user's + memory_config_id, then delegates to MemoryAgentService.write_memory. Args: workspace_id: Workspace ID for resource validation end_user_id: End user identifier (used as end_user_id) message: Message content to store - config_id: Optional memory configuration ID + config_id: Memory configuration ID (required) storage_type: Storage backend (neo4j or rag) user_rag_memory_id: Optional RAG memory ID @@ -136,7 +158,8 @@ class MemoryAPIService: # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - # Use end_user_id as end_user_id for memory operations + # Update end user's memory_config_id + self._update_end_user_config(end_user_id, config_id) try: # Delegate to MemoryAgentService @@ -188,21 +211,21 @@ class MemoryAPIService: end_user_id: str, message: str, search_switch: str = "0", - config_id: Optional[str] = None, + config_id: str = "", storage_type: str = "neo4j", user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Read memory with validation. - Validates end_user exists and belongs to workspace, then delegates - to MemoryAgentService.read_memory. + Validates end_user exists and belongs to workspace, updates the end user's + memory_config_id, then delegates to MemoryAgentService.read_memory. Args: workspace_id: Workspace ID for resource validation end_user_id: End user identifier (used as end_user_id) message: Query message search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search) - config_id: Optional memory configuration ID + config_id: Memory configuration ID (required) storage_type: Storage backend (neo4j or rag) user_rag_memory_id: Optional RAG memory ID @@ -218,7 +241,8 @@ class MemoryAPIService: # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - # Use end_user_id as end_user_id for memory operations + # Update end user's memory_config_id + self._update_end_user_config(end_user_id, config_id) try: @@ -256,3 +280,50 @@ class MemoryAPIService: message=f"Memory read failed: {str(e)}", code=BizCode.MEMORY_READ_FAILED ) + + def list_memory_configs( + self, + workspace_id: uuid.UUID, + ) -> Dict[str, Any]: + """List all memory configs for a workspace. + + Args: + workspace_id: Workspace ID from API key authorization + + Returns: + Dict with configs list and total count + + Raises: + BusinessException: If listing fails + """ + logger.info(f"Listing memory configs for workspace: {workspace_id}") + + try: + from app.repositories.memory_config_repository import MemoryConfigRepository + + results = MemoryConfigRepository.get_all(self.db, workspace_id=workspace_id) + + configs = [] + for config, scene_name in results: + configs.append({ + "config_id": str(config.config_id), + "config_name": config.config_name, + "config_desc": config.config_desc, + "is_default": config.is_default or False, + "scene_name": scene_name, + "created_at": config.created_at.isoformat() if config.created_at else None, + "updated_at": config.updated_at.isoformat() if config.updated_at else None, + }) + + logger.info(f"Found {len(configs)} memory configs for workspace {workspace_id}") + return { + "configs": configs, + "total": len(configs), + } + + except Exception as e: + logger.error(f"Failed to list memory configs for workspace {workspace_id}: {e}") + raise BusinessException( + message=f"Failed to list memory configs: {str(e)}", + code=BizCode.MEMORY_READ_FAILED + ) diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index 84c4aff6..a0bcc1a1 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -619,7 +619,7 @@ class MemoryForgetService: recent_trends.append({ 'date': date_str, 'merged_count': record.merged_count, - 'average_activation': record.average_activation_value, + 'average_activation': round(record.average_activation_value, 2) if record.average_activation_value is not None else None, 'total_nodes': record.total_nodes, 'execution_time': int(record.execution_time.timestamp() * 1000) }) diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index f0c7cee2..6cb0a7f0 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -11,6 +11,8 @@ import base64 import io import uuid +import zipfile +import chardet from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional @@ -42,12 +44,10 @@ PDF_MIME = ['application/pdf'] DOC_MIME = [ 'application/msword', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - 'application/zip' ] XLSX_MIME = [ 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', 'application/vnd.ms-excel', - 'application/zip' ] CSV_MIME = ['text/csv', 'application/csv'] JSON_MIME = ['application/json'] @@ -418,6 +418,71 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result + async def history_process_files( + self, + files: Optional[List[FileInput]], + ) -> List[Dict[str, Any]]: + """ + 处理文件列表,返回 LLM 可用的格式 + + Args: + files: 文件输入列表 + + Returns: + List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式) + """ + if not files: + return [] + + # 获取对应的策略 + # dashscope 的 omni 模型使用 OpenAI 兼容格式 + if self.provider == "dashscope" and self.is_omni: + strategy_class = OpenAIFormatStrategy + else: + strategy_class = PROVIDER_STRATEGIES.get(self.provider) + if not strategy_class: + logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略") + strategy_class = DashScopeFormatStrategy + + result = [] + for idx, file in enumerate(files): + strategy = strategy_class(file) + if not file.url: + file.url = await self.get_file_url(file) + try: + if file.type == FileType.IMAGE and "vision" in self.capability: + is_support, content = await self._process_image(file, strategy) + result.append(content) + elif file.type == FileType.DOCUMENT: + is_support, content = await self._process_document(file, strategy) + result.append(content) + elif file.type == FileType.AUDIO and "audio" in self.capability: + is_support, content = await self._process_audio(file, strategy) + result.append(content) + elif file.type == FileType.VIDEO and "video" in self.capability: + is_support, content = await self._process_video(file, strategy) + result.append(content) + else: + logger.warning(f"不支持的文件类型: {file.type}") + except Exception as e: + logger.error( + f"处理文件失败", + extra={ + "file_index": idx, + "file_type": file.type, + "error": str(e) + }, + exc_info=True + ) + # 继续处理其他文件,不中断整个流程 + result.append({ + "type": "text", + "text": f"[文件处理失败: {str(e)}]" + }) + + logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") + return result + def write_perceptual_memory( self, end_user_id: str, @@ -588,12 +653,12 @@ class MultimodalService: file.set_content(file_content) file_mime_type = magic.from_buffer(file_content, mime=True) if file_mime_type in TEXT_MIME: - return file_content.decode("utf-8") + return self._decode_text_safe(file_content) elif file_mime_type in PDF_MIME: return await self._extract_pdf_text(file_content) - elif file_mime_type in DOC_MIME and file.file_type.endswith(('docx', 'doc')): + elif self._is_word_file(file_content, file_mime_type): return await self._extract_word_text(file_content) - elif file_mime_type in XLSX_MIME and file.file_type.endswith(("xlsx", "xls")): + elif self._is_excel_file(file_content, file_mime_type): return await self._extract_xlsx_text(file_content) elif file_mime_type in CSV_MIME: return await self._extract_csv_text(file_content) @@ -622,52 +687,156 @@ class MultimodalService: @staticmethod async def _extract_word_text(file_content: bytes) -> str: - """提取 Word 文档文本""" + """提取 Word 文档文本(支持 .docx 和旧版 .doc)""" + # 先尝试 docx(ZIP 格式) + if file_content[:2] == b'PK': + try: + word_file = io.BytesIO(file_content) + doc = Document(word_file) + return '\n'.join(p.text for p in doc.paragraphs) + except Exception as e: + logger.error(f"提取 docx 文本失败: {e}") + return f"[docx 提取失败: {str(e)}]" + + # 旧版 .doc(OLE2 格式) try: - word_file = io.BytesIO(file_content) - doc = Document(word_file) - text_parts = [paragraph.text for paragraph in doc.paragraphs] - return '\n'.join(text_parts) + import olefile + ole = olefile.OleFileIO(io.BytesIO(file_content)) + if not ole.exists('WordDocument'): + return "[doc 提取失败: 未找到 WordDocument 流]" + # 读取 WordDocument 流,提取可见 ASCII/Unicode 文本 + stream = ole.openstream('WordDocument').read() + # Word Binary Format: 文本在流中以 UTF-16-LE 编码存储 + # 简单提取:过滤出可打印字符段 + try: + text = stream.decode('utf-16-le', errors='ignore') + except Exception: + text = stream.decode('latin-1', errors='ignore') + # 过滤控制字符,保留可打印内容 + import re + text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text) + text = re.sub(r' +', ' ', text).strip() + ole.close() + return text except Exception as e: - logger.error(f"提取 Word 文本失败: {e}") - return f"[Word 提取失败: {str(e)}]" + logger.error(f"提取 doc 文本失败: {e}") + return f"[doc 提取失败: {str(e)}]" @staticmethod async def _extract_xlsx_text(file_content: bytes) -> str: - """提取 Excel 文本""" + """提取 Excel 文本(支持 .xlsx 和旧版 .xls)""" + # xlsx(ZIP 格式) + if file_content[:2] == b'PK': + try: + wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True) + parts = [] + for sheet in wb.worksheets: + parts.append(f"[Sheet: {sheet.title}]") + for row in sheet.iter_rows(values_only=True): + parts.append('\t'.join('' if v is None else str(v) for v in row)) + return '\n'.join(parts) + except Exception as e: + logger.error(f"提取 xlsx 文本失败: {e}") + return f"[xlsx 提取失败: {str(e)}]" + + # xls(OLE2/BIFF 格式) try: - wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True) + import xlrd + wb = xlrd.open_workbook(file_contents=file_content) parts = [] - for sheet in wb.worksheets: - parts.append(f"[Sheet: {sheet.title}]") - for row in sheet.iter_rows(values_only=True): - parts.append('\t'.join('' if v is None else str(v) for v in row)) + for sheet in wb.sheets(): + parts.append(f"[Sheet: {sheet.name}]") + for row_idx in range(sheet.nrows): + parts.append('\t'.join(str(sheet.cell_value(row_idx, col)) for col in range(sheet.ncols))) return '\n'.join(parts) except Exception as e: - logger.error(f"提取 Excel 文本失败: {e}") - return f"[Excel 提取失败: {str(e)}]" + logger.error(f"提取 xls 文本失败: {e}") + return f"[xls 提取失败: {str(e)}]" - @staticmethod - async def _extract_csv_text(file_content: bytes) -> str: + async def _extract_csv_text(self, file_content: bytes) -> str: """提取 CSV 文本""" try: - text = file_content.decode('utf-8-sig') + text = self._decode_text_safe(file_content) reader = csv.reader(io.StringIO(text)) return '\n'.join('\t'.join(row) for row in reader) except Exception as e: logger.error(f"提取 CSV 文本失败: {e}") return f"[CSV 提取失败: {str(e)}]" - @staticmethod - async def _extract_json_text(file_content: bytes) -> str: + async def _extract_json_text(self, file_content: bytes) -> str: """提取 JSON 文本""" try: - data = json.loads(file_content.decode('utf-8')) + text = self._decode_text_safe(file_content) + data = json.loads(text) return json.dumps(data, ensure_ascii=False, indent=2) except Exception as e: logger.error(f"提取 JSON 文本失败: {e}") return f"[JSON 提取失败: {str(e)}]" + def _is_word_file(self, file_content: bytes, mime_type: str) -> bool: + """判断是不是 Word 文件(doc / docx),不依赖后缀""" + # 旧版 .doc + if mime_type == 'application/msword': + return True + + # 新版 .docx(ZIP 内部包含 word/document.xml) + header = file_content[:4] + if header == b'PK\x03\x04': + try: + with zipfile.ZipFile(io.BytesIO(file_content)) as zf: + return "word/document.xml" in zf.namelist() + except: + pass + + return False + + def _is_excel_file(self, file_content: bytes, mime_type: str) -> bool: + """判断是不是 Excel 文件(xls / xlsx),不依赖后缀""" + # 旧版 .xls + if mime_type == 'application/vnd.ms-excel': + return True + + # 新版 .xlsx(ZIP 内部包含 xl/workbook.xml) + header = file_content[:4] + if header == b'PK\x03\x04': + try: + with zipfile.ZipFile(io.BytesIO(file_content)) as zf: + return "xl/workbook.xml" in zf.namelist() + except: + pass + + return False + + @staticmethod + def _decode_text_safe(file_content: bytes) -> str: + """ + 【万能文本解码】 + 自动检测编码,支持 utf-8 / gbk / gb2312 / utf-8-sig / ascii 等 + 永远不报错,永远不乱码 + """ + if not file_content: + return "" + + # 1. 自动检测文件编码 + detect = chardet.detect(file_content) + encoding = detect.get("encoding") or "utf-8" + encoding = encoding.lower() + + # 2. 兼容常见中文编码 + compatible_encodings = ["utf-8", "gbk", "gb18030", "gb2312", "ascii", "latin-1"] + + # 3. 按优先级尝试解码 + for enc in [encoding] + compatible_encodings: + if not enc: + continue + try: + return file_content.decode(enc.strip()) + except (UnicodeDecodeError, LookupError): + continue + + # 终极兜底 + return file_content.decode("utf-8", errors="replace") + def get_multimodal_service(db: Session) -> MultimodalService: """获取多模态服务实例(依赖注入)""" diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index d5d19e0d..12e0c324 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -1408,12 +1408,11 @@ async def analytics_memory_types( if end_user_id: try: conversation_repo = ConversationRepository(db) - conversations = conversation_repo.get_conversation_by_user_id( + conversations, total = conversation_repo.get_conversation_by_user_id( user_id=uuid.UUID(end_user_id), - limit=100, # 获取更多会话以准确统计 is_activate=True ) - work_count = len(conversations) + work_count = total logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})") except Exception as e: logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}") diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 04a778a1..aee3d75f 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -25,7 +25,7 @@ from app.repositories.workflow_repository import ( WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) -from app.schemas import DraftRunRequest, FileInput, FileType +from app.schemas import DraftRunRequest, FileInput from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str from app.services.multimodal_service import MultimodalService @@ -55,6 +55,7 @@ class WorkflowService: edges: list[dict[str, Any]], variables: list[dict[str, Any]] | None = None, execution_config: dict[str, Any] | None = None, + features: dict[str, Any] | None = None, triggers: list[dict[str, Any]] | None = None, validate: bool = True ) -> WorkflowConfig: @@ -66,6 +67,7 @@ class WorkflowService: edges: 边列表 variables: 变量列表 execution_config: 执行配置 + features: 功能特性 triggers: 触发器列表 validate: 是否验证配置 @@ -81,6 +83,7 @@ class WorkflowService: "edges": edges, "variables": variables or [], "execution_config": execution_config or {}, + "features": features or {}, "triggers": triggers or [] } @@ -101,6 +104,7 @@ class WorkflowService: edges=edges, variables=variables, execution_config=execution_config, + features=features, triggers=triggers ) diff --git a/api/app/tasks.py b/api/app/tasks.py index f5258330..3a237d82 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -2675,13 +2675,15 @@ def write_perceptual_memory( time_limit=7200, # 2小时硬超时 soft_time_limit=6900, ) -def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: +def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]: """触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。 由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。 + 任务完成且所有用户数据均完整时,写入 Redis 标记,避免下次重复投递。 Args: end_user_ids: 需要检查的用户 ID 列表 + workspace_id: 工作空间 ID,用于完成标记 Returns: 包含任务执行结果的字典 @@ -2707,6 +2709,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s # 批量预取所有用户的配置(内置兜底:用户配置不可用时自动回退到工作空间默认配置) user_llm_map: Dict[str, Optional[str]] = {} + user_embedding_map: Dict[str, Optional[str]] = {} try: with get_db_context() as db: from app.services.memory_agent_service import get_end_users_connected_configs_batch @@ -2718,21 +2721,54 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s try: cfg = MemoryConfigService(db).load_memory_config(config_id=config_id) user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None + user_embedding_map[uid] = str(cfg.embedding_model_id) if cfg.embedding_model_id else None except Exception as e: - logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}") + logger.warning(f"[CommunityCluster] 用户 {uid} 加载配置失败,将使用 None: {e}") user_llm_map[uid] = None + user_embedding_map[uid] = None else: user_llm_map[uid] = None + user_embedding_map[uid] = None except Exception as e: - logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}") + logger.warning(f"[CommunityCluster] 批量获取配置失败,所有用户将使用 None: {e}") for end_user_id in end_user_ids: try: - # 已有社区节点则跳过 + # 已有社区节点时,检查是否存在属性不完整的节点 has_communities = await repo.has_communities(end_user_id) if has_communities: - skipped += 1 - logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过") + llm_model_id = user_llm_map.get(end_user_id) + embedding_model_id = user_embedding_map.get(end_user_id) + incomplete_ids = await repo.get_incomplete_communities( + end_user_id, check_embedding=bool(embedding_model_id) + ) + if not incomplete_ids: + skipped += 1 + logger.debug(f"[CommunityCluster] 用户 {end_user_id} 社区节点均完整,跳过") + continue + + # 对不完整的社区节点逐一补全元数据 + engine = LabelPropagationEngine( + connector=connector, + llm_model_id=llm_model_id, + embedding_model_id=embedding_model_id, + ) + logger.info( + f"[CommunityCluster] 用户 {end_user_id} 发现 {len(incomplete_ids)} 个属性不完整的社区,开始补全" + ) + patch_ok = 0 + patch_fail = 0 + for cid in incomplete_ids: + try: + await engine._generate_community_metadata(cid, end_user_id) + patch_ok += 1 + except Exception as patch_err: + patch_fail += 1 + logger.error(f"[CommunityCluster] 社区 {cid} 元数据补全失败: {patch_err}") + logger.info( + f"[CommunityCluster] 用户 {end_user_id} 社区补全完成: 成功={patch_ok}, 失败={patch_fail}" + ) + initialized += 1 continue # 检查是否有 ExtractedEntity 节点 @@ -2742,11 +2778,13 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过") continue - # 每个用户使用自己的 llm_model_id + # 每个用户使用自己的 llm_model_id / embedding_model_id llm_model_id = user_llm_map.get(end_user_id) + embedding_model_id = user_embedding_map.get(end_user_id) engine = LabelPropagationEngine( connector=connector, llm_model_id=llm_model_id, + embedding_model_id=embedding_model_id, ) logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") diff --git a/api/app/version_info.json b/api/app/version_info.json index 12793cb5..b4f6976f 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,38 @@ { + "v0.2.8": { + "introduction": { + "codeName": "景玉", + "releaseDate": "2026-3-20", + "upgradePosition": "🐻 MemoryBear v0.2.8 社区版全面升级应用共享、多模态交互与平台基础设施,引入语音交互、感知记忆和云端存储,打造更强大的开放 AI 记忆平台", + "coreUpgrades": [ + "1. 应用共享与发布
* 应用共享(Agent、工作流、Agent 集群):全类型应用共享至其他空间
* 分享应用默认开启记忆功能:发布分享后记忆默认开启,关闭时提醒
* 工作流记忆分享规则:按记忆配置自动控制分享页记忆开关
* 分享会话联网搜索修复:恢复分享应用的联网搜索能力", + "2. 多模态与交互 💬
* 语音输入:模型接口和应用支持语音输入
* 语音回复:应用支持语音回复模态
* 多模态感知记忆:记忆系统支持视觉、音频、图片和文件的感知记忆
* 对话框文件展示:试运行和体验分享中正确展示上传文件", + "3. 平台与基础设施 ⚙️
* i18n 国际化:全面多语言多地区支持
* 云端文件存储(OSS + S3):支持阿里云 OSS 和 S3 云端上传
* Flower 容器监控:Celery 异步任务监控与管理", + "4. EndUser 身份迁移 🔐
* EndUser 从 app_id 迁移至 workspace_id:身份从应用级迁移至工作空间级", + "5. 情景记忆 🧠
* 情景记忆聚类算法:基于社区图谱的聚类算法,支持老用户图谱生成", + "6. 稳健性与缺陷修复 🔧
* MCP 服务删除后工具 404:修复删除 MCP 服务后接口报错
* 应用导出配置不一致:导出已保存配置而非画布状态
* 工作流节点 ID 重复:修复复制节点后 ID 冲突
* 条件分支连线错误:修复保存刷新后连线错乱
* 回复节点内容丢失:修复点击画布后内容消失
* 连接桩规则优化:禁止非法连接方向
* 知识库状态列宽度:锁定或自适应宽度
* 等待中文档预览:支持未完成解析文档预览
* 知识库关联修复:统一修复关联问题
* 多模态对话连续性:修复多模态内容后无法继续对话
* 时区统一:环境变量统一控制存储和任务时区
* 遗忘强度精度:修复小数显示过长", + "
", + "v0.2.8 社区版在应用共享和多模态交互方面实现重大升级,感知记忆扩展了平台的认知维度。后续将深化多智能体协作、情景记忆聚类,并持续优化平台稳定性与开放生态。", + "MemoryBear —— 让 AI 拥有记忆 🐻✨" + ] + }, + "introduction_en": { + "codeName": "JingYu", + "releaseDate": "2026-3-20", + "upgradePosition": "🐻 MemoryBear v0.2.8 Community delivers multimodal interaction, perceptual memory, cloud storage, and workspace-level identity for a more capable open AI memory platform", + "coreUpgrades": [ + "1. Application Sharing & Publishing
* Application Sharing (Agent, Workflow, Agent Cluster): Full sharing across all app types
* Memory Enabled by Default: Memory auto-enabled on shared apps with disable reminder
* Workflow Memory Sharing Rules: Auto-controlled based on memory configuration
* Shared Session Web Search Fix: Restored web search for shared apps", + "2. Multimodal & Interaction 💬
* Voice Input: Model interfaces and apps support voice input
* Voice Reply: Apps support voice reply modality
* Multimodal Perceptual Memory: Memory system supports visual, audio, image, and file perception
* File Display in Chat: Uploaded files display correctly in dry-run and sharing", + "3. Platform & Infrastructure ⚙️
* i18n Internationalization: Full multi-language multi-region support
* Cloud File Storage (OSS + S3): Alibaba Cloud OSS and S3 cloud uploads
* Flower Container Monitoring: Celery async task monitoring and management", + "4. EndUser Identity Migration 🔐
* EndUser Migration from app_id to workspace_id: Identity migrated to workspace level", + "5. Episodic Memory 🧠
* Episodic Memory Clustering: Community-graph-based clustering with legacy user support", + "6. Robustness & Bug Fixes 🔧
* MCP Service Deletion 404: Fixed tool endpoint error after MCP removal
* App Export Config Mismatch: Exports saved config instead of canvas state
* Workflow Duplicate Node ID: Fixed ID conflict on node duplication
* Conditional Branch Wiring: Fixed wiring reset after save/refresh
* Reply Node Content Loss: Fixed content disappearing on canvas click
* Port Connection Rules: Prohibited invalid connection directions
* Knowledge Base Status Width: Locked or adaptive column width
* Pending Document Preview: Preview support for unparsed documents
* Knowledge Base Association Fixes: Consolidated association fixes
* Multimodal Conversation Continuity: Fixed single-round limit after multimodal input
* Timezone Unification: Env-var controlled unified timezone
* Forgetting Strength Precision: Fixed excessive decimal display", + "
", + "v0.2.8 Community delivers major upgrades in application sharing and multimodal interaction, with perceptual memory expanding the platform's cognitive dimensions. Multi-agent collaboration, episodic clustering, and continued platform stability improvements are ahead.", + "MemoryBear — Give AI Memory 🐻✨" + ] + } + }, "v0.2.7": { "introduction": { "codeName": "武陵", diff --git a/api/tests/workflow/executor/test_vairable_pool.py b/api/tests/workflow/executor/test_vairable_pool.py index 3404eb79..0ba4d259 100644 --- a/api/tests/workflow/executor/test_vairable_pool.py +++ b/api/tests/workflow/executor/test_vairable_pool.py @@ -303,7 +303,7 @@ async def test_get_node_output_not_exist_with_default(): """测试获取不存在的节点输出(使用默认值)""" pool = VariablePool() - result = pool.get_node_output("nonexistent_node", defalut=None, strict=False) + result = pool.get_node_output("nonexistent_node", default=None, strict=False) assert result is None diff --git a/web/src/api/knowledgeBase.ts b/web/src/api/knowledgeBase.ts index 60ed2403..63ec80ae 100644 --- a/web/src/api/knowledgeBase.ts +++ b/web/src/api/knowledgeBase.ts @@ -52,6 +52,10 @@ export const getKnowledgeBaseTypeList = async (): Promise => { // 如果不是数组,返回空数组 return []; }; +// 获取文件地址 +export const getFileUrl = (fileId: string) => { + return `${apiPrefix}/files/${fileId}`; +}; // 知识库文档解析类型 export const getKnowledgeBaseDocumentParseTypeList = async () => { const response = await request.get(`${apiPrefix}/knowledges/parsertype`); diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index 823e3d78..9a464893 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 14:00:06 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-13 10:48:41 + * @Last Modified time: 2026-03-19 18:35:10 */ import { request } from '@/utils/request' import type { AxiosRequestConfig } from 'axios' @@ -218,8 +218,8 @@ export const getExplicitMemory = (end_user_id: string) => { export const getExplicitMemoryDetails = (data: { end_user_id: string, memory_id: string; }) => { return request.post(`/memory/explicit-memory/details`, data) } -export const getConversations = (end_user_id: string) => { - return request.get(`/memory/work/${end_user_id}/conversations`) +export const getConversations = (end_user_id: string, page = 1, pagesize = 20) => { + return request.get(`/memory/work/${end_user_id}/conversations`, { page, pagesize }) } export const getConversationMessages = (end_user_id: string, conversation_id: string) => { return request.get(`/memory/work/${end_user_id}/messages`, { conversation_id }) diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index 2824381e..aa6f28bd 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -143,15 +143,20 @@ const ChatContent: FC = ({ } return (
handleDownload(file)}> - {(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) &&
} - {(file.type.includes('pdf')) &&
} - {(file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv')) &&
} + {(file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv')) + ?
+ :(file.type.includes('pdf')) + ?
+ : (file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) + ?
+ : null + }
) })} diff --git a/web/src/components/Chat/ChatToolbar.tsx b/web/src/components/Chat/ChatToolbar.tsx index 64c3f03e..3816f790 100644 --- a/web/src/components/Chat/ChatToolbar.tsx +++ b/web/src/components/Chat/ChatToolbar.tsx @@ -49,6 +49,7 @@ interface FormValues { memory?: boolean; } +const max_file_count = 1; const ChatToolbar = forwardRef(({ features, leftExtra, @@ -86,10 +87,16 @@ const ChatToolbar = forwardRef(({ // Append newly uploaded file to the file list when upload is complete const fileChange = (file?: any) => { - if (file?.status !== 'done') return - const files = [...(queryValues?.files || []), file] - form.setFieldValue('files', files) - onFilesChange?.(files) + console.log('file', file) + const lastFiles = form.getFieldValue('files') || []; + const index = lastFiles.findIndex((item: any) => item.uid === file.uid) + if (index > -1) { + lastFiles[index] = file + } else { + lastFiles.push(file) + } + form.setFieldValue('files', [...lastFiles]) + onFilesChange?.([...lastFiles]) } // Append recorded audio file to the file list and notify parent @@ -129,8 +136,8 @@ const ChatToolbar = forwardRef(({ key: 'url', label: t('memoryConversation.addRemoteFile'), onClick: () => { - if ((queryValues?.files?.length || 0) >= file_upload.max_file_count) { - messageApi.warning(t('common.fileNumTip', { num: file_upload.max_file_count })) + if ((queryValues?.files?.length || 0) >= max_file_count) { + messageApi.warning(t('common.fileNumTip', { num: max_file_count })) return } uploadFileListModalRef.current?.handleOpen() @@ -146,7 +153,7 @@ const ChatToolbar = forwardRef(({ onChange={fileChange} requestConfig={uploadRequestConfig} featureConfig={file_upload} - disabled={(queryValues?.files?.length || 0) >= file_upload.max_file_count} + disabled={(queryValues?.files?.length || 0) >= max_file_count} /> ) }) @@ -184,7 +191,7 @@ const ChatToolbar = forwardRef(({ {rightExtra} {file_upload?.audio_enabled && file_upload?.allowed_transfer_methods?.includes('local_file') && = file_upload.max_file_count} + disabled={(queryValues?.files?.length || 0) >= max_file_count} action={uploadAction} requestConfig={uploadRequestConfig} onRecordingComplete={handleRecordingComplete} diff --git a/web/src/components/DocumentPreview/index.tsx b/web/src/components/DocumentPreview/index.tsx index 247f713e..f659c53e 100644 --- a/web/src/components/DocumentPreview/index.tsx +++ b/web/src/components/DocumentPreview/index.tsx @@ -4,7 +4,7 @@ * @Author: yujiangping * @Date: 2026-03-16 19:01:12 * @LastEditors: yujiangping - * @LastEditTime: 2026-03-18 18:35:53 + * @LastEditTime: 2026-03-20 12:12:20 */ import { useState, useEffect, useRef, useCallback, type FC } from 'react'; import { Spin, Alert, Button, Table, InputNumber, Image } from 'antd'; @@ -309,23 +309,64 @@ const DocumentPreview: FC = ({ } }; + const [csvTruncated, setCsvTruncated] = useState(false); + const isCsvFile = () => getFileExtension() === '.csv'; + // CSV 预览大小限制:1MB + const CSV_PREVIEW_SIZE = 1 * 1024 * 1024; + // 最大预览行数 + const MAX_PREVIEW_ROWS = 500; + + const fetchFileBufferWithLimit = async (url: string, maxBytes?: number): Promise => { + const requestUrl = getRequestUrl(url); + const headers: Record = { + 'Authorization': `Bearer ${cookieUtils.get('authToken') || ''}`, + }; + if (maxBytes) { + headers['Range'] = `bytes=0-${maxBytes - 1}`; + } + const response = await fetch(requestUrl, { + credentials: 'include', + headers, + }); + if (!response.ok && response.status !== 206) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + return response.arrayBuffer(); + }; + const loadExcelFile = async () => { setLoading(true); setError(false); setErrorMessage(''); + setCsvTruncated(false); try { - const arrayBuffer = await fetchFileBuffer(fileUrl); - - // CSV 文件需要处理编码问题(可能是 GBK/GB2312) + // CSV 文件需要处理编码问题(可能是 GBK/GB2312),且大文件只取前 1MB if (isCsvFile()) { + let arrayBuffer: ArrayBuffer; + let truncated = false; + try { + // 先尝试 Range 请求只取前 1MB + arrayBuffer = await fetchFileBufferWithLimit(fileUrl, CSV_PREVIEW_SIZE); + // 如果返回的数据刚好等于限制大小,说明可能被截断了 + if (arrayBuffer.byteLength >= CSV_PREVIEW_SIZE) { + truncated = true; + } + } catch { + // Range 请求不支持时,全量获取后截断 + const fullBuffer = await fetchFileBuffer(fileUrl); + if (fullBuffer.byteLength > CSV_PREVIEW_SIZE) { + arrayBuffer = fullBuffer.slice(0, CSV_PREVIEW_SIZE); + truncated = true; + } else { + arrayBuffer = fullBuffer; + } + } + let csvText: string; - // 先尝试 UTF-8 解码 const utf8Text = new TextDecoder('utf-8').decode(arrayBuffer); - // 检测是否有乱码特征(常见的 GBK 被错误解析为 UTF-8 的替换字符) if (utf8Text.includes('\uFFFD') || /[\x80-\xff]/.test(utf8Text.slice(0, 200))) { - // 尝试 GBK 解码 try { csvText = new TextDecoder('gbk').decode(arrayBuffer); } catch { @@ -334,19 +375,35 @@ const DocumentPreview: FC = ({ } else { csvText = utf8Text; } + + // 如果被截断,去掉最后一行不完整的数据 + if (truncated) { + const lastNewline = csvText.lastIndexOf('\n'); + if (lastNewline > 0) { + csvText = csvText.substring(0, lastNewline); + } + } + const workbook = XLSX.read(csvText, { type: 'string' }); const sheets = workbook.SheetNames.map(sheetName => { const worksheet = workbook.Sheets[sheetName]; - const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][]; + let data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][]; + // 限制最大行数 + if (data.length > MAX_PREVIEW_ROWS + 1) { + data = data.slice(0, MAX_PREVIEW_ROWS + 1); // +1 保留表头 + truncated = true; + } return { sheetName, data }; }); + setCsvTruncated(truncated); setExcelData(sheets); setLoading(false); return; } + const arrayBuffer = await fetchFileBuffer(fileUrl); const workbook = XLSX.read(arrayBuffer, { type: 'array' }); - const sheets = workbook.SheetNames.map(sheetName => { + const sheets = workbook.SheetNames.map((sheetName: string) => { const worksheet = workbook.Sheets[sheetName]; const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][]; return { sheetName, data }; @@ -522,9 +579,14 @@ const DocumentPreview: FC = ({ ) )} - {/* Excel 预览 */} + {/* Excel/CSV 预览 */} {isExcelFile() && !error && !loading && (
+ {csvTruncated && ( +
+ 文件较大,仅预览前 {MAX_PREVIEW_ROWS} 行数据 +
+ )} {excelData.map((sheet, index) => (

{sheet.sheetName}

@@ -541,6 +603,7 @@ const DocumentPreview: FC = ({ scroll={{ x: 'max-content' }} size="small" bordered + virtual /> )}
diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 47de99a8..7b7900f3 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -469,6 +469,7 @@ export const en = { download: 'Download', view: 'View', updated_at: 'Updated At', + callbackUrlInvalid: 'Please enter a valid URL', }, model: { searchPlaceholder: 'search model…', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 9fa8cc0d..06a3bd74 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1106,6 +1106,7 @@ export const zh = { download: '下载', view: '查看', updated_at: '更新时间', + callbackUrlInvalid: '请输入有效的 URL', }, model: { searchPlaceholder: '搜索模型…', diff --git a/web/src/views/ApplicationConfig/TestChat/index.tsx b/web/src/views/ApplicationConfig/TestChat/index.tsx index ad7931e2..c324622d 100644 --- a/web/src/views/ApplicationConfig/TestChat/index.tsx +++ b/web/src/views/ApplicationConfig/TestChat/index.tsx @@ -183,7 +183,7 @@ const TestChat: FC = ({ const handleSend = () => { if (loading || !application || !message || !message?.trim()) return - const files = toolbarRef.current?.getFiles() || [] + const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status)) const variables = toolbarRef.current?.getVariables() || [] const { isCanSend, params } = buildVariableParams(variables) if (!isCanSend) return @@ -235,7 +235,7 @@ const TestChat: FC = ({ const handleWorkflowSend = () => { if (loading || !application || !message || !message?.trim()) return - const files = toolbarRef.current?.getFiles() || [] + const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status)) const variables = toolbarRef.current?.getVariables() || [] const { isCanSend, params } = buildVariableParams(variables) if (!isCanSend) return diff --git a/web/src/views/ApplicationConfig/components/Chat.tsx b/web/src/views/ApplicationConfig/components/Chat.tsx index 56e1088b..38225104 100644 --- a/web/src/views/ApplicationConfig/components/Chat.tsx +++ b/web/src/views/ApplicationConfig/components/Chat.tsx @@ -189,7 +189,7 @@ const Chat: FC = ({ .then(() => { const message = msg if (!message?.trim()) return - const files = toolbarRef.current?.getFiles() || [] + const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status)) // Validate required variables before sending let isCanSend = true const params: Record = {} @@ -350,7 +350,7 @@ const Chat: FC = ({ .then(() => { const message = msg if (!message || message.trim() === '') return - const files = toolbarRef.current?.getFiles() || [] + const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status)) addUserMessage(message, files) setMessage(undefined) toolbarRef.current?.setFiles([]) diff --git a/web/src/views/ApplicationConfig/components/FeaturesConfig/FeaturesConfigModal.tsx b/web/src/views/ApplicationConfig/components/FeaturesConfig/FeaturesConfigModal.tsx index 5fcb752d..d712720f 100644 --- a/web/src/views/ApplicationConfig/components/FeaturesConfig/FeaturesConfigModal.tsx +++ b/web/src/views/ApplicationConfig/components/FeaturesConfig/FeaturesConfigModal.tsx @@ -24,7 +24,7 @@ interface FeaturesConfigModalProps { refresh: (value: FeaturesConfigForm) => void; source?: Application['type']; } - +const max_file_count = 1; /** * Modal for copying applications */ @@ -133,7 +133,7 @@ const FeaturesConfigModal = forwardRef
{t('application.maxCount')}
- {fu.max_file_count} {t('application.unix')} + {max_file_count} {t('application.unix')}
diff --git a/web/src/views/ApplicationConfig/components/FeaturesConfig/FileUploadSettingModal.tsx b/web/src/views/ApplicationConfig/components/FeaturesConfig/FileUploadSettingModal.tsx index 3fb05a0e..f33b313b 100644 --- a/web/src/views/ApplicationConfig/components/FeaturesConfig/FileUploadSettingModal.tsx +++ b/web/src/views/ApplicationConfig/components/FeaturesConfig/FileUploadSettingModal.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-03-05 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-19 15:18:20 + * @Last Modified time: 2026-03-19 20:19:14 */ import { forwardRef, useImperativeHandle, useState } from 'react'; import { Form, InputNumber, Flex, Switch, Row, Col, Radio } from 'antd'; @@ -82,28 +82,27 @@ const defaultValues: FileUpload = { "mp3", "wav", "m4a", - "ogg", - "flac" ], document_enabled: false, document_max_size_mb: 100, document_allowed_extensions: [ "pdf", "docx", + "doc", "xlsx", + "xls", "txt", "csv", - "json" + "json", + "md", ], video_enabled: false, video_max_size_mb: 100, video_allowed_extensions: [ "mp4", "mov", - "avi", - "webm" ], - max_file_count: 5, + max_file_count: 1, allowed_transfer_methods: 'both' } @@ -168,8 +167,8 @@ const FileUploadSettingModal = forwardRef -
{t('application.maxCount')}
- + {/*
{t('application.maxCount')}
*/} + diff --git a/web/src/views/Conversation/components/FileUpload.tsx b/web/src/views/Conversation/components/FileUpload.tsx index 166b00c8..8c646bea 100644 --- a/web/src/views/Conversation/components/FileUpload.tsx +++ b/web/src/views/Conversation/components/FileUpload.tsx @@ -23,7 +23,7 @@ import { useState, useEffect, forwardRef, useImperativeHandle, useMemo } from 'react'; import { Upload, Progress, App, Flex } from 'antd'; import type { UploadProps, UploadFile } from 'antd'; -import type { UploadProps as RcUploadProps } from 'antd/es/upload/interface'; +import type { UploadProps as RcUploadProps, RcFile, UploadFileStatus } from 'antd/es/upload/interface'; import { useTranslation } from 'react-i18next'; import { request } from '@/utils/request' @@ -221,17 +221,29 @@ const UploadFiles = forwardRef(({ */ const handleCustomRequest: RcUploadProps['customRequest'] = async (options) => { const { file, onSuccess, onError } = options; - - try { - const formData = new FormData(); - formData.append('file', file); - - const response = await request.uploadFile(action, formData, requestConfig); - - onSuccess?.({data: response}); - } catch (error) { - onError?.(error as Error); + if (typeof file === 'string') return; + const rcFile = file as RcFile; + const formData = new FormData(); + formData.append('file', rcFile); + const fileVo: UploadFile = { + uid: rcFile.uid, + name: rcFile.name, + status: 'uploading' as UploadFileStatus, + percent: 0, + type: rcFile.type, + originFileObj: rcFile, + thumbUrl: URL.createObjectURL(rcFile) } + onChange?.(fileVo) + request.uploadFile(action, formData, requestConfig) + .then(res => { + onSuccess?.({ data: res }); + }) + .catch((error) => { + onError?.(error as Error); + fileVo.status = 'error' + onChange?.(fileVo) + }) }; /** diff --git a/web/src/views/Conversation/components/UploadFileListModal.tsx b/web/src/views/Conversation/components/UploadFileListModal.tsx index ce71066d..4d2e83ee 100644 --- a/web/src/views/Conversation/components/UploadFileListModal.tsx +++ b/web/src/views/Conversation/components/UploadFileListModal.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:09:47 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-18 21:10:01 + * @Last Modified time: 2026-03-19 20:32:32 */ /** * Upload File List Modal Component @@ -19,7 +19,10 @@ * @component */ import { forwardRef, useImperativeHandle, useState, useMemo } from 'react'; -import { Form, Input, Select, Button, Flex } from 'antd'; +import { Form, Input, Select, + // Button, + Flex +} from 'antd'; import { useTranslation } from 'react-i18next'; import type { UploadFileListModalRef } from '../types' @@ -105,9 +108,11 @@ const UploadFileListModal = forwardRef -
+ - {(fields, { add, remove }) => ( + {(fields, + // { add, remove } + ) => ( <> {/* Render each file entry with type selector and URL input */} {fields.map(({ key, name, ...restField }) => ( @@ -116,6 +121,9 @@ const UploadFileListModal = forwardRef -
remove(name)} - >
+ >
*/} ))} - + {/* - + */} )} diff --git a/web/src/views/Conversation/index.tsx b/web/src/views/Conversation/index.tsx index 55d17073..3f822917 100644 --- a/web/src/views/Conversation/index.tsx +++ b/web/src/views/Conversation/index.tsx @@ -194,7 +194,7 @@ const Conversation: FC = () => { /** Send message and handle streaming response */ const handleSend = () => { if (!token || !shareToken) return - const files = toolbarRef.current?.getFiles() || [] + const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status)) const variables = toolbarRef.current?.getVariables() || [] let isCanSend = true const params: Record = {} diff --git a/web/src/views/KnowledgeBase/[knowledgeBaseId]/DocumentDetails.tsx b/web/src/views/KnowledgeBase/[knowledgeBaseId]/DocumentDetails.tsx index 4b52b7fe..8859a8c8 100644 --- a/web/src/views/KnowledgeBase/[knowledgeBaseId]/DocumentDetails.tsx +++ b/web/src/views/KnowledgeBase/[knowledgeBaseId]/DocumentDetails.tsx @@ -11,7 +11,7 @@ import { useNavigate, useParams, useLocation } from 'react-router-dom'; import { useTranslation } from 'react-i18next'; import { useBreadcrumbManager, type BreadcrumbPath } from '@/hooks/useBreadcrumbManager'; import { Button, Spin, message, Switch } from 'antd'; -import { getDocumentDetail, getDocumentChunkList, downloadFile, updateDocument, updateDocumentChunk, createDocumentChunk } from '@/api/knowledgeBase'; +import { getDocumentDetail, getDocumentChunkList, downloadFile, updateDocument, updateDocumentChunk, createDocumentChunk, getFileUrl } from '@/api/knowledgeBase'; import type { KnowledgeBaseDocumentData, RecallTestData } from '@/views/KnowledgeBase/types'; import { formatDateTime } from '@/utils/format'; import InfoPanel, { type InfoItem } from '../components/InfoPanel'; @@ -138,7 +138,7 @@ const DocumentDetails: FC = () => { const response = await getDocumentDetail(documentId); setDocument(response); setInfoItems(formatDocumentInfo(response)); - const url = `${imagePath}/api/files/${response.file_id}` + const url = `${window.location.origin}/api/files/${response.file_id}`; setFileUrl(url); setParserMode(response?.parser_config?.auto_questions || 0) // ChunkList will be called automatically in useEffect based on document.progress diff --git a/web/src/views/UserMemoryDetail/components/RelationshipNetwork.tsx b/web/src/views/UserMemoryDetail/components/RelationshipNetwork.tsx index 6ffc61e7..df7d639e 100644 --- a/web/src/views/UserMemoryDetail/components/RelationshipNetwork.tsx +++ b/web/src/views/UserMemoryDetail/components/RelationshipNetwork.tsx @@ -191,24 +191,28 @@ const RelationshipNetwork: FC = () => { })}> {(selectedNode as RawCommunityNode).properties.community_id ?
-
- {(selectedNode as RawCommunityNode).properties.name} -
-
{t('userMemory.summary')}
-
- {(selectedNode as RawCommunityNode).properties.summary} -
- - {t('userMemory.member_count')} - {(selectedNode as RawCommunityNode).properties.member_count}{t('userMemory.member_count_desc')} - +
+ {(selectedNode as RawCommunityNode).properties.name || selectedNode.id} +
+ {(selectedNode as RawCommunityNode).properties.summary && <> +
{t('userMemory.summary')}
+
+ {(selectedNode as RawCommunityNode).properties.summary} +
+ } + + {t('userMemory.member_count')} + {(selectedNode as RawCommunityNode).properties.member_count}{t('userMemory.member_count_desc')} + - -
{t('userMemory.core_entities')}
-
    - {(selectedNode as RawCommunityNode).properties.core_entities.map((entity, index) =>
  • {entity}
  • )} -
-
+ {(selectedNode as RawCommunityNode).properties.core_entities && <> + +
{t('userMemory.core_entities')}
+
    + {(selectedNode as RawCommunityNode).properties.core_entities?.map((entity, index) =>
  • {entity}
  • )} +
+ } + : <> {(selectedNode as Node).name &&
diff --git a/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx b/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx index ea859dce..edb9e526 100644 --- a/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx @@ -4,12 +4,14 @@ * @Last Modified by: ZhaoYing * @Last Modified time: 2026-03-16 15:10:17 */ -import { type FC, useEffect, useState, useMemo } from 'react' +import { type FC, useEffect, useState, useMemo, useRef } from 'react' import clsx from 'clsx' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom' import { Row, Col, Skeleton, Button, Divider, Tooltip, Flex } from 'antd' + +import InfiniteScroll from 'react-infinite-scroll-component' import RbCard from '@/components/RbCard/Card' import { getConversations, @@ -61,6 +63,8 @@ const WorkingDetail: FC = () => { const { id } = useParams() const [loading, setLoading] = useState(false) const [data, setData] = useState([]) + const [hasMore, setHasMore] = useState(true) + const pageRef = useRef(1) const [messagesLoading, setMessagesLoading] = useState(false) const [messages, setMessages] = useState([]) const [detailLoading, setDetailLoading] = useState(false) @@ -80,17 +84,30 @@ const WorkingDetail: FC = () => { setSelected(null) setDetail(null) setData([]) - getConversations(id).then((res) => { - const response = res as Conversation[] - setData(response) - setSelected(response[0] || null) + setHasMore(true) + pageRef.current = 1 + getConversations(id, 1).then((res) => { + const response = res as { items: Conversation[], page: { hasnext: boolean } } + setData(response.items) + setSelected(response.items[0] || null) + setHasMore(response.page.hasnext) }) .finally(() => { setLoading(false) }) } - /* Load messages and AI insight whenever the selected conversation changes. */ + const loadMore = () => { + if (!id) return + const nextPage = pageRef.current + 1 + getConversations(id, nextPage).then((res) => { + const response = res as {items: Conversation[], page: { hasnext: boolean }} + setData(prev => [...prev, ...response.items]) + pageRef.current = nextPage + setHasMore(response.page.hasnext) + }) + } + useEffect(() => { if (!id || !selected || !selected.id) return getDetail(selected.id) @@ -138,16 +155,16 @@ const WorkingDetail: FC = () => { : data.length === 0 ? :( - - - - + + +
+ {data.map(item => ( { ))} - - + +
{selected && <> diff --git a/web/src/views/Workflow/components/Chat/Chat.tsx b/web/src/views/Workflow/components/Chat/Chat.tsx index 8832cb1d..830c277c 100644 --- a/web/src/views/Workflow/components/Chat/Chat.tsx +++ b/web/src/views/Workflow/components/Chat/Chat.tsx @@ -151,7 +151,7 @@ const Chat = forwardRef !['uploading', 'error'].includes(item.status)) setChatList(prev => [...prev, { role: 'user', content: message, diff --git a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx index b263120a..8fe29d19 100644 --- a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx @@ -18,8 +18,8 @@ const InitialValuePlugin: React.FC = ({ value, options const isUserInputRef = useRef(false); useEffect(() => { - // 监听编辑器变化,标记是否为用户输入 - const removeListener = editor.registerUpdateListener(({ editorState }) => { + const removeListener = editor.registerUpdateListener(({ editorState, tags }) => { + if (tags.has('programmatic')) return; editorState.read(() => { const root = $getRoot(); const textContent = root.getTextContent(); @@ -107,7 +107,7 @@ const InitialValuePlugin: React.FC = ({ value, options }); root.append(paragraph); } - }, { discrete: true }); + }, { discrete: true, tag: 'programmatic' }); }); }