From d423e80ddb1cd379b372c0e1fa167ba4d4e0c992 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Tue, 23 Dec 2025 17:45:37 +0800 Subject: [PATCH 01/13] feat(workflow): support multi-variable assignment in assigner node --- .../core/workflow/nodes/assigner/config.py | 21 +++-- api/app/core/workflow/nodes/assigner/node.py | 90 +++++++++---------- 2 files changed, 61 insertions(+), 50 deletions(-) diff --git a/api/app/core/workflow/nodes/assigner/config.py b/api/app/core/workflow/nodes/assigner/config.py index 1cb0def3..03302af4 100644 --- a/api/app/core/workflow/nodes/assigner/config.py +++ b/api/app/core/workflow/nodes/assigner/config.py @@ -1,21 +1,32 @@ -from pydantic import Field +from pydantic import Field, BaseModel from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.enums import AssignmentOperator -class AssignerNodeConfig(BaseNodeConfig): +class AssignmentItem(BaseModel): + """ + Single assignment definition. + """ + variable_selector: str | list[str] = Field( ..., - description="Variables to be assigned", + description="Target variable name(s) to assign", ) operation: AssignmentOperator = Field( ..., - description="Operator to assign", + description="Assignment operator", ) value: str | list[str] = Field( ..., - description="Values to assign", + description="Value(s) to assign to the variable(s)", + ) + + +class AssignerNodeConfig(BaseNodeConfig): + assignments: list[AssignmentItem] = Field( + ..., + description="List of variable assignment definitions", ) diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index eb32bf8b..b8b7c1f4 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -29,52 +29,52 @@ class AssignerNode(BaseNode): """ # Initialize a variable pool for accessing conversation, node, and system variables pool = VariablePool(state) + for assignment in self.typed_config.assignments: + # Get the target variable selector (e.g., "conv.test") + variable_selector = assignment.variable_selector + if isinstance(variable_selector, str): + # Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"] + variable_selector = variable_selector.split('.') - # Get the target variable selector (e.g., "conv.test") - variable_selector = self.typed_config.variable_selector - if isinstance(variable_selector, str): - # Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"] - variable_selector = variable_selector.split('.') + # Only conversation variables ('conv') are allowed + if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature) + raise ValueError("Only conversation variables can be assigned.") - # Only conversation variables ('conv') are allowed - if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature) - raise ValueError("Only conversation variables can be assigned.") + # Get the value or expression to assign + value = assignment.value + if isinstance(value, list): + value = '.'.join(value) + value = ExpressionEvaluator.evaluate( + expression=value, + variables=pool.get_all_conversation_vars(), + node_outputs=pool.get_all_node_outputs(), + system_vars=pool.get_all_system_vars(), + ) - # Get the value or expression to assign - value = self.typed_config.value - if isinstance(value, list): - value = '.'.join(value) - value = ExpressionEvaluator.evaluate( - expression=value, - variables=pool.get_all_conversation_vars(), - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars(), - ) + # Select the appropriate assignment operator instance based on the target variable type + operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))( + pool, variable_selector, value + ) - # Select the appropriate assignment operator instance based on the target variable type - operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))( - pool, variable_selector, value - ) - - # Execute the configured assignment operation - match self.typed_config.operation: - case AssignmentOperator.ASSIGN: - operator.assign() - case AssignmentOperator.CLEAR: - operator.clear() - case AssignmentOperator.ADD: - operator.add() - case AssignmentOperator.SUBTRACT: - operator.subtract() - case AssignmentOperator.MULTIPLY: - operator.multiply() - case AssignmentOperator.DIVIDE: - operator.divide() - case AssignmentOperator.APPEND: - operator.append() - case AssignmentOperator.REMOVE_FIRST: - operator.remove_first() - case AssignmentOperator.REMOVE_LAST: - operator.remove_last() - case _: - raise ValueError(f"Invalid Operator: {self.typed_config.operation}") + # Execute the configured assignment operation + match assignment.operation: + case AssignmentOperator.ASSIGN: + operator.assign() + case AssignmentOperator.CLEAR: + operator.clear() + case AssignmentOperator.ADD: + operator.add() + case AssignmentOperator.SUBTRACT: + operator.subtract() + case AssignmentOperator.MULTIPLY: + operator.multiply() + case AssignmentOperator.DIVIDE: + operator.divide() + case AssignmentOperator.APPEND: + operator.append() + case AssignmentOperator.REMOVE_FIRST: + operator.remove_first() + case AssignmentOperator.REMOVE_LAST: + operator.remove_last() + case _: + raise ValueError(f"Invalid Operator: {assignment.operation}") From 1f6abb29259b092f62fb53c28449fedb638392c5 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Wed, 24 Dec 2025 10:39:13 +0800 Subject: [PATCH 02/13] fix(docker-compose): update image names and restore container name usage for MCP server URL --- .../memory/agent/mcp_server/mcp_instance.py | 5 ++- api/docker-compose.yml | 44 +++++++++++++++++-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/api/app/core/memory/agent/mcp_server/mcp_instance.py b/api/app/core/memory/agent/mcp_server/mcp_instance.py index 3a2eeb78..c072a438 100644 --- a/api/app/core/memory/agent/mcp_server/mcp_instance.py +++ b/api/app/core/memory/agent/mcp_server/mcp_instance.py @@ -8,4 +8,7 @@ from mcp.server.fastmcp import FastMCP # Initialize FastMCP server instance # This instance is shared across all tool modules -mcp = FastMCP('data_flow') +mcp = FastMCP( + 'data_flow', + allowed_hosts=["mcp-server", "localhost"] +) \ No newline at end of file diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 48ec137d..e0919c3b 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -26,9 +26,35 @@ services: - default - celery + # FastAPI application - connects to MCP server + # MCP Server - standalone service + mcp-server: + image: redbear-mem-open:latest + container_name: mcp-server + ports: + - "8081:8081" # MCP server port + env_file: + - .env + environment: + - SERVER_IP=0.0.0.0 # Bind to all interfaces + volumes: + - ./files:/files + - /etc/localtime:/etc/localtime:ro + command: python -m app.core.memory.agent.mcp_server.server + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8081/sse')"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 30s + restart: unless-stopped + networks: + - default + - celery + # FastAPI application - connects to MCP server api: - image: redbear-mem:latest + image: redbear-mem-open:latest container_name: api ports: - "8002:8000" @@ -37,6 +63,9 @@ services: environment: - MCP_SERVER_URL=http://mcp-server:8081 - SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces + environment: + - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name + - SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro @@ -51,15 +80,16 @@ services: # Celery worker - connects to MCP server worker: - image: redbear-mem:latest + image: redbear-mem-open:latest container_name: worker env_file: - .env environment: - - MCP_SERVER_URL=http://mcp-server:8081 + - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro + - /etc/localtime:/etc/localtime:ro command: celery -A app.celery_worker.celery_app worker --loglevel=info depends_on: mcp-server: @@ -67,5 +97,13 @@ services: restart: unless-stopped networks: - celery +networks: + celery: + depends_on: + mcp-server: + condition: service_healthy + restart: unless-stopped + networks: + - celery networks: celery: \ No newline at end of file From 11dc66960c27312105d2f218d49dd3e73acbd96d Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Wed, 24 Dec 2025 10:40:34 +0800 Subject: [PATCH 03/13] mcp quick fix --- api/docker-compose.yml | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/api/docker-compose.yml b/api/docker-compose.yml index e0919c3b..8470a5d1 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -1,32 +1,6 @@ version: '3.9' services: - # MCP Server - standalone service - mcp-server: - image: redbear-mem:latest - container_name: mcp-server - ports: - - "8081:8081" # MCP server port - env_file: - - .env - environment: - - SERVER_IP=0.0.0.0 # Bind to all interfaces - volumes: - - ./files:/files - - /etc/localtime:/etc/localtime:ro - command: python -m app.core.memory.agent.mcp_server.server - healthcheck: - test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8081/sse')"] - interval: 10s - timeout: 5s - retries: 5 - start_period: 30s - restart: unless-stopped - networks: - - default - - celery - - # FastAPI application - connects to MCP server # MCP Server - standalone service mcp-server: image: redbear-mem-open:latest @@ -60,9 +34,6 @@ services: - "8002:8000" env_file: - .env - environment: - - MCP_SERVER_URL=http://mcp-server:8081 - - SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces environment: - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name - SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces @@ -89,7 +60,6 @@ services: volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro - - /etc/localtime:/etc/localtime:ro command: celery -A app.celery_worker.celery_app worker --loglevel=info depends_on: mcp-server: @@ -97,13 +67,5 @@ services: restart: unless-stopped networks: - celery -networks: - celery: - depends_on: - mcp-server: - condition: service_healthy - restart: unless-stopped - networks: - - celery networks: celery: \ No newline at end of file From 7e28ec0f229061ad7602d06cb0d0742b201fa7d8 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Wed, 24 Dec 2025 11:12:45 +0800 Subject: [PATCH 04/13] Refactor MCP server initialization and enhance Docker compatibility - Simplified FastMCP initialization by removing unnecessary allowed_hosts parameter. - Added logging for MCP server startup details. - Implemented DNS rebinding protection configuration to support Docker container hostnames. --- api/app/core/memory/agent/mcp_server/mcp_instance.py | 5 +---- api/app/core/memory/agent/mcp_server/server.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/api/app/core/memory/agent/mcp_server/mcp_instance.py b/api/app/core/memory/agent/mcp_server/mcp_instance.py index c072a438..3a2eeb78 100644 --- a/api/app/core/memory/agent/mcp_server/mcp_instance.py +++ b/api/app/core/memory/agent/mcp_server/mcp_instance.py @@ -8,7 +8,4 @@ from mcp.server.fastmcp import FastMCP # Initialize FastMCP server instance # This instance is shared across all tool modules -mcp = FastMCP( - 'data_flow', - allowed_hosts=["mcp-server", "localhost"] -) \ No newline at end of file +mcp = FastMCP('data_flow') diff --git a/api/app/core/memory/agent/mcp_server/server.py b/api/app/core/memory/agent/mcp_server/server.py index 18ea911f..f87ed529 100644 --- a/api/app/core/memory/agent/mcp_server/server.py +++ b/api/app/core/memory/agent/mcp_server/server.py @@ -147,6 +147,18 @@ def main(): # Get MCP port from environment (default: 8081) mcp_port = int(os.getenv("MCP_PORT", "8081")) + logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") + + # Configure DNS rebinding protection for Docker container compatibility + from mcp.server.fastmcp.server import TransportSecuritySettings + + # Disable DNS rebinding protection to allow Docker container hostnames + # This allows containers to connect using service names like 'mcp-server' + mcp.settings.transport_security = TransportSecuritySettings( + enable_dns_rebinding_protection=False, + ) + logger.info("DNS rebinding protection: disabled for Docker container compatibility") + # logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") # Run the server with SSE transport for HTTP connections From 879b3da7ef5039ac3864eb31f8161a10eb535364 Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Wed, 24 Dec 2025 11:12:46 +0800 Subject: [PATCH 05/13] [fix] mineru parser --- api/app/core/rag/deepdoc/parser/mineru_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/core/rag/deepdoc/parser/mineru_parser.py b/api/app/core/rag/deepdoc/parser/mineru_parser.py index ec380922..fe6178ec 100644 --- a/api/app/core/rag/deepdoc/parser/mineru_parser.py +++ b/api/app/core/rag/deepdoc/parser/mineru_parser.py @@ -13,7 +13,7 @@ from io import BytesIO from os import PathLike from pathlib import Path from queue import Empty, Queue -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import numpy as np import pdfplumber @@ -439,7 +439,7 @@ class MinerUParser(RAGPdfParser): def parse_pdf( self, filepath: str | PathLike[str], - binary: BytesIO | bytes, + binary: Optional[Union[BytesIO, bytes]] = None, callback: Optional[Callable] = None, *, output_dir: Optional[str] = None, From ea9e5a689a6bdc6206d60db0e850d2ccb23fefa3 Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Wed, 24 Dec 2025 11:59:16 +0800 Subject: [PATCH 06/13] [ADD] Add RAG Setting --- api/env.example | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/api/env.example b/api/env.example index 1354233d..8ceb3934 100644 --- a/api/env.example +++ b/api/env.example @@ -71,6 +71,18 @@ ENABLE_SINGLE_SESSION= MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024 FILE_PATH=/files +# RAG Setting +DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1 +HF_ENDPOINT=https://hf-mirror.com +MINERU_EXECUTABLE=mineru +MINERU_APISERVER=http://host.docker.internal:9987 +MINERU_OUTPUT_DIR=/files +MINERU_BACKEND=pipeline +MINERU_DELETE_OUTPUT=1 +TEXTLN_APISERVER=https://api.textin.com/ai/service/v1/pdf_to_markdown +TEXTLN_APP_ID= +TEXTLN_SECRET_CODE= + # VOLC ASR VOLC_APP_KEY= VOLC_ACCESS_KEY= From 8c4d31e4d589847e2d6b2ead5777138bd308158b Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Wed, 24 Dec 2025 12:10:52 +0800 Subject: [PATCH 07/13] feat(workflow): implement a workflow node for knowledge base retrieval --- .../vdb/elasticsearch/elasticsearch_vector.py | 29 +++--- api/app/core/workflow/nodes/__init__.py | 4 +- api/app/core/workflow/nodes/configs.py | 4 +- .../core/workflow/nodes/knowledge/__init__.py | 4 + .../core/workflow/nodes/knowledge/config.py | 38 ++++++++ api/app/core/workflow/nodes/knowledge/node.py | 97 +++++++++++++++++++ api/app/core/workflow/nodes/node_factory.py | 6 +- api/app/core/workflow/nodes/start/node.py | 33 +++---- api/app/repositories/knowledge_repository.py | 2 +- api/app/services/knowledge_service.py | 2 +- 10 files changed, 179 insertions(+), 40 deletions(-) create mode 100644 api/app/core/workflow/nodes/knowledge/__init__.py create mode 100644 api/app/core/workflow/nodes/knowledge/config.py create mode 100644 api/app/core/workflow/nodes/knowledge/node.py diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index 176f996a..198d1473 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -740,8 +740,9 @@ class ElasticSearchVector(BaseVector): self._client.indices.create(index=self._collection_name, body=index_mapping) -class ElasticSearchVectorFactory(ABC): - def init_vector(self, knowledge: Knowledge) -> ElasticSearchVector: +class ElasticSearchVectorFactory: + @staticmethod + def init_vector(knowledge: Knowledge) -> ElasticSearchVector: collection_name = f"Vector_index_{knowledge.id}_Node" # Use regular Elasticsearch with config values @@ -763,17 +764,17 @@ class ElasticSearchVectorFactory(ABC): } ) - if knowledge.embedding and knowledge.reranker: - return ElasticSearchVector( - index_name=collection_name, - config=ElasticSearchConfig(**config_dict), - embedding_config=knowledge.embedding.api_keys[0], - reranker_config=knowledge.reranker.api_keys[0] - ) - else: - if knowledge.embedding is None: - raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}") - if knowledge.reranker is None: - raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}") + if knowledge.embedding is None: + raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}") + if knowledge.reranker is None: + raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}") + + return ElasticSearchVector( + index_name=collection_name, + config=ElasticSearchConfig(**config_dict), + embedding_config=knowledge.embedding.api_keys[0], + reranker_config=knowledge.reranker.api_keys[0] + ) + diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index 1d00532e..fa5a5a2b 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -9,7 +9,7 @@ from app.core.workflow.nodes.assigner import AssignerNode from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.if_else import IfElseNode -# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.start import StartNode @@ -26,6 +26,6 @@ __all__ = [ "EndNode", "NodeFactory", "WorkflowNode", - # "KnowledgeRetrievalNode", + "KnowledgeRetrievalNode", "AssignerNode", ] diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index ecded070..e9f102f0 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -14,7 +14,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig from app.core.workflow.nodes.agent.config import AgentNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.if_else.config import IfElseNodeConfig -# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.assigner.config import AssignerNodeConfig __all__ = [ @@ -30,6 +30,6 @@ __all__ = [ "AgentNodeConfig", "TransformNodeConfig", "IfElseNodeConfig", - # "KnowledgeRetrievalNodeConfig", + "KnowledgeRetrievalNodeConfig", "AssignerNodeConfig", ] diff --git a/api/app/core/workflow/nodes/knowledge/__init__.py b/api/app/core/workflow/nodes/knowledge/__init__.py new file mode 100644 index 00000000..25d0f00b --- /dev/null +++ b/api/app/core/workflow/nodes/knowledge/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.knowledge.node import KnowledgeRetrievalNode + +__all__ = ["KnowledgeRetrievalNode", "KnowledgeRetrievalNodeConfig"] diff --git a/api/app/core/workflow/nodes/knowledge/config.py b/api/app/core/workflow/nodes/knowledge/config.py new file mode 100644 index 00000000..530116ff --- /dev/null +++ b/api/app/core/workflow/nodes/knowledge/config.py @@ -0,0 +1,38 @@ +from uuid import UUID + +from pydantic import Field + +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.schemas.chunk_schema import RetrieveType + + +class KnowledgeRetrievalNodeConfig(BaseNodeConfig): + query: str = Field( + ..., + description="Search query string" + ) + + kb_ids: list[UUID] = Field( + ..., + description="Knowledge base IDs" + ) + + similarity_threshold: float = Field( + default=0.2, + description="Knowledge base similarity threshold" + ) + + vector_similarity_weight: float = Field( + default=0.3, + description="Knowledge base vector similarity weight" + ) + + top_k: int = Field( + default=4, + description="Knowledge base top k" + ) + + retrieve_type: RetrieveType = Field( + default=RetrieveType.PARTICIPLE, + description="Retrieve type" + ) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py new file mode 100644 index 00000000..72e8750f --- /dev/null +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -0,0 +1,97 @@ +import logging +from typing import Any + +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig +from app.db import get_db +from app.models import knowledge_model, knowledgeshare_model +from app.repositories import knowledge_repository +from app.schemas.chunk_schema import RetrieveType +from app.services import knowledge_service, knowledgeshare_service + +logger = logging.getLogger(__name__) + + +class KnowledgeRetrievalNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) + + async def execute(self, state: WorkflowState) -> Any: + query = self._render_template(self.typed_config.query, state) + db_gen = get_db() + db = next(db_gen) + try: + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + existing_ids = knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + if share_ids: + filters = [ + knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids) + ] + items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( + db=db, + filters=filters + ) + existing_ids.extend(items) + + if not existing_ids: + raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") + + kb_id = existing_ids[0] + uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] + indices = ",".join(uuid_strs) + + db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) + if not db_knowledge: + raise RuntimeError("The knowledge base does not exist or access is denied.") + + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + match self.typed_config.retrieve_type: + case RetrieveType.PARTICIPLE: + rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.similarity_threshold) + return [chunk.model_dump() for chunk in rs] + case RetrieveType.SEMANTIC: + rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + return [chunk.model_dump() for chunk in rs] + case _: + rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.similarity_threshold) + # Efficient deduplication + seen_ids = set() + unique_rs = [] + for doc in rs1 + rs2: + if doc.metadata["doc_id"] not in seen_ids: + seen_ids.add(doc.metadata["doc_id"]) + unique_rs.append(doc) + rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) + return [chunk.model_dump() for chunk in rs] + finally: + next(db_gen) diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 93364083..2ae31d4d 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -7,7 +7,7 @@ import logging from typing import Any, Union -# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.end import EndNode @@ -29,7 +29,7 @@ WorkflowNode = Union[ AgentNode, TransformNode, AssignerNode, - # KnowledgeRetrievalNode, + KnowledgeRetrievalNode, ] @@ -47,7 +47,7 @@ class NodeFactory: NodeType.AGENT: AgentNode, NodeType.TRANSFORM: TransformNode, NodeType.IF_ELSE: IfElseNode, - # NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.ASSIGNER: AssignerNode, } diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index 0acf04b0..7c3a2fca 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -23,7 +23,7 @@ class StartNode(BaseNode): 注意:变量的验证和默认值处理由 Executor 在初始化时完成。 """ - + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): """初始化 Start 节点 @@ -32,10 +32,10 @@ class StartNode(BaseNode): workflow_config: 工作流配置 """ super().__init__(node_config, workflow_config) - + # 解析并验证配置 self.typed_config = StartNodeConfig(**self.config) - + async def execute(self, state: WorkflowState) -> dict[str, Any]: """执行 start 节点业务逻辑 @@ -48,13 +48,13 @@ class StartNode(BaseNode): 包含系统参数、会话变量和自定义变量的字典 """ logger.info(f"节点 {self.node_id} (Start) 开始执行") - + # 创建变量池实例(在方法内复用) pool = self.get_variable_pool(state) - + # 处理自定义变量(传入 pool 避免重复创建) custom_vars = self._process_custom_variables(pool) - + # 返回业务数据(包含自定义变量) result = { "message": pool.get("sys.message"), @@ -64,14 +64,14 @@ class StartNode(BaseNode): "user_id": pool.get("sys.user_id"), **custom_vars # 自定义变量作为节点输出的一部分 } - + logger.info( f"节点 {self.node_id} (Start) 执行完成," f"输出了 {len(custom_vars)} 个自定义变量" ) - + return result - + def _process_custom_variables(self, pool) -> dict[str, Any]: """处理自定义变量 @@ -88,34 +88,33 @@ class StartNode(BaseNode): """ # 获取输入数据中的自定义变量 input_variables = pool.get("sys.input_variables", default={}) - + processed = {} - + # 遍历配置的变量定义 for var_def in self.typed_config.variables: var_name = var_def.name - + # 检查变量是否存在 if var_name in input_variables: # 使用用户提供的值 processed[var_name] = input_variables[var_name] - + elif var_def.required: # 必需变量缺失 raise ValueError( f"缺少必需的输入变量: {var_name}" + (f" ({var_def.description})" if var_def.description else "") ) - + elif var_def.default is not None: # 使用默认值 processed[var_name] = var_def.default logger.debug( f"变量 '{var_name}' 使用默认值: {var_def.default}" ) - + return processed - def _extract_input(self, state: WorkflowState) -> dict[str, Any]: """提取输入数据(用于记录) @@ -127,7 +126,7 @@ class StartNode(BaseNode): 输入数据字典 """ pool = self.get_variable_pool(state) - + return { "execution_id": pool.get("sys.execution_id"), "conversation_id": pool.get("sys.conversation_id"), diff --git a/api/app/repositories/knowledge_repository.py b/api/app/repositories/knowledge_repository.py index 5d4946fa..b7908cb0 100644 --- a/api/app/repositories/knowledge_repository.py +++ b/api/app/repositories/knowledge_repository.py @@ -52,7 +52,7 @@ def get_knowledges_paginated( raise -def get_chunded_knowledgeids( +def get_chunked_knowledgeids( db: Session, filters: list ) -> list: diff --git a/api/app/services/knowledge_service.py b/api/app/services/knowledge_service.py index b9d97c29..cf47fd4f 100644 --- a/api/app/services/knowledge_service.py +++ b/api/app/services/knowledge_service.py @@ -45,7 +45,7 @@ def get_chunded_knowledgeids( business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}") try: - items = knowledge_repository.get_chunded_knowledgeids( + items = knowledge_repository.get_chunked_knowledgeids( db=db, filters=filters ) From b99671e04a34be1ce8756fad75331f7ee0ef87ad Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Wed, 24 Dec 2025 12:12:11 +0800 Subject: [PATCH 08/13] fix(template): remove default initial model in templates --- api/app/templates/workflows/simple_qa/template.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/app/templates/workflows/simple_qa/template.yml b/api/app/templates/workflows/simple_qa/template.yml index 1b68d55d..ab1af3c2 100644 --- a/api/app/templates/workflows/simple_qa/template.yml +++ b/api/app/templates/workflows/simple_qa/template.yml @@ -44,11 +44,11 @@ nodes: - role: user content: "{{ sys.message }}" - model_id: gpt-3.5-turbo + model_id: null temperature: 0.7 max_tokens: 1000 position: - x: 300 + x: 500 y: 100 - id: end @@ -57,7 +57,7 @@ nodes: config: output: "{{ llm_qa.output }}" position: - x: 500 + x: 900 y: 100 edges: From 38220006a6ec5b39a4d00f67f1d0ae976f6468da Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Wed, 24 Dec 2025 12:21:12 +0800 Subject: [PATCH 09/13] fix(db): fix database connection handling --- api/app/core/workflow/nodes/knowledge/node.py | 124 +++++++++--------- 1 file changed, 60 insertions(+), 64 deletions(-) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 72e8750f..a9a76743 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -20,78 +20,74 @@ class KnowledgeRetrievalNode(BaseNode): async def execute(self, state: WorkflowState) -> Any: query = self._render_template(self.typed_config.query, state) - db_gen = get_db() - db = next(db_gen) - try: + db = next(get_db()) + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + existing_ids = knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + if share_ids: filters = [ - knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), - knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, - knowledge_model.Knowledge.chunk_num > 0, - knowledge_model.Knowledge.status == 1 + knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids) ] - existing_ids = knowledge_repository.get_chunked_knowledgeids( + items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( db=db, filters=filters ) - filters = [ - knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), - knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, - knowledge_model.Knowledge.chunk_num > 0, - knowledge_model.Knowledge.status == 1 - ] - share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - if share_ids: - filters = [ - knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids) - ] - items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( - db=db, - filters=filters - ) - existing_ids.extend(items) + existing_ids.extend(items) - if not existing_ids: - raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") + if not existing_ids: + raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") - kb_id = existing_ids[0] - uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] - indices = ",".join(uuid_strs) + kb_id = existing_ids[0] + uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] + indices = ",".join(uuid_strs) - db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) - if not db_knowledge: - raise RuntimeError("The knowledge base does not exist or access is denied.") + db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) + if not db_knowledge: + raise RuntimeError("The knowledge base does not exist or access is denied.") - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - match self.typed_config.retrieve_type: - case RetrieveType.PARTICIPLE: - rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.similarity_threshold) - return [chunk.model_dump() for chunk in rs] - case RetrieveType.SEMANTIC: - rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + match self.typed_config.retrieve_type: + case RetrieveType.PARTICIPLE: + rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.similarity_threshold) + return [chunk.model_dump() for chunk in rs] + case RetrieveType.SEMANTIC: + rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + return [chunk.model_dump() for chunk in rs] + case _: + rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, indices=indices, - score_threshold=self.typed_config.vector_similarity_weight) - return [chunk.model_dump() for chunk in rs] - case _: - rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.similarity_threshold) - # Efficient deduplication - seen_ids = set() - unique_rs = [] - for doc in rs1 + rs2: - if doc.metadata["doc_id"] not in seen_ids: - seen_ids.add(doc.metadata["doc_id"]) - unique_rs.append(doc) - rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) - return [chunk.model_dump() for chunk in rs] - finally: - next(db_gen) + score_threshold=self.typed_config.similarity_threshold) + # Efficient deduplication + seen_ids = set() + unique_rs = [] + for doc in rs1 + rs2: + if doc.metadata["doc_id"] not in seen_ids: + seen_ids.add(doc.metadata["doc_id"]) + unique_rs.append(doc) + rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) + return [chunk.model_dump() for chunk in rs] From 0a8c1be084f7c80226d3ea990741fe5ce1776427 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Wed, 24 Dec 2025 12:22:59 +0800 Subject: [PATCH 10/13] fix(db): fix database connection handling --- api/app/core/workflow/nodes/knowledge/node.py | 122 +++++++++--------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index a9a76743..cb80db5b 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -4,7 +4,7 @@ from typing import Any from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig -from app.db import get_db +from app.db import get_db_context from app.models import knowledge_model, knowledgeshare_model from app.repositories import knowledge_repository from app.schemas.chunk_schema import RetrieveType @@ -20,74 +20,74 @@ class KnowledgeRetrievalNode(BaseNode): async def execute(self, state: WorkflowState) -> Any: query = self._render_template(self.typed_config.query, state) - db = next(get_db()) - filters = [ - knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), - knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, - knowledge_model.Knowledge.chunk_num > 0, - knowledge_model.Knowledge.status == 1 - ] - existing_ids = knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - filters = [ - knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), - knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, - knowledge_model.Knowledge.chunk_num > 0, - knowledge_model.Knowledge.status == 1 - ] - share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - if share_ids: + with get_db_context(): filters = [ - knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids) + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 ] - items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( + existing_ids = knowledge_repository.get_chunked_knowledgeids( db=db, filters=filters ) - existing_ids.extend(items) + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + if share_ids: + filters = [ + knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids) + ] + items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( + db=db, + filters=filters + ) + existing_ids.extend(items) - if not existing_ids: - raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") + if not existing_ids: + raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") - kb_id = existing_ids[0] - uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] - indices = ",".join(uuid_strs) + kb_id = existing_ids[0] + uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] + indices = ",".join(uuid_strs) - db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) - if not db_knowledge: - raise RuntimeError("The knowledge base does not exist or access is denied.") + db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) + if not db_knowledge: + raise RuntimeError("The knowledge base does not exist or access is denied.") - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - match self.typed_config.retrieve_type: - case RetrieveType.PARTICIPLE: - rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.similarity_threshold) - return [chunk.model_dump() for chunk in rs] - case RetrieveType.SEMANTIC: - rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.vector_similarity_weight) - return [chunk.model_dump() for chunk in rs] - case _: - rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + match self.typed_config.retrieve_type: + case RetrieveType.PARTICIPLE: + rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.similarity_threshold) + return [chunk.model_dump() for chunk in rs] + case RetrieveType.SEMANTIC: + rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, indices=indices, - score_threshold=self.typed_config.similarity_threshold) - # Efficient deduplication - seen_ids = set() - unique_rs = [] - for doc in rs1 + rs2: - if doc.metadata["doc_id"] not in seen_ids: - seen_ids.add(doc.metadata["doc_id"]) - unique_rs.append(doc) - rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) - return [chunk.model_dump() for chunk in rs] + score_threshold=self.typed_config.vector_similarity_weight) + return [chunk.model_dump() for chunk in rs] + case _: + rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.similarity_threshold) + # Efficient deduplication + seen_ids = set() + unique_rs = [] + for doc in rs1 + rs2: + if doc.metadata["doc_id"] not in seen_ids: + seen_ids.add(doc.metadata["doc_id"]) + unique_rs.append(doc) + rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) + return [chunk.model_dump() for chunk in rs] From 1253bedcde99e32e83fe542df7c4f0aa0f71d0f6 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Wed, 24 Dec 2025 12:23:16 +0800 Subject: [PATCH 11/13] fix(db): fix database connection handling --- api/app/core/workflow/nodes/knowledge/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index cb80db5b..97ebaa82 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -20,7 +20,7 @@ class KnowledgeRetrievalNode(BaseNode): async def execute(self, state: WorkflowState) -> Any: query = self._render_template(self.typed_config.query, state) - with get_db_context(): + with get_db_context() as db: filters = [ knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, From 63d5047d217cc95594c734ef7dde6d972eea2b88 Mon Sep 17 00:00:00 2001 From: Mark Date: Wed, 24 Dec 2025 12:37:50 +0800 Subject: [PATCH 12/13] [fix] end stream output --- api/app/core/workflow/nodes/end/node.py | 97 ++++++++++++++----------- 1 file changed, 53 insertions(+), 44 deletions(-) diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 3cece96b..efc62dc5 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -9,28 +9,29 @@ import re import asyncio from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.enums import NodeType logger = logging.getLogger(__name__) class EndNode(BaseNode): """End 节点 - + 工作流的结束节点,根据配置的模板输出最终结果。 支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。 """ - + async def execute(self, state: WorkflowState) -> str: """执行 end 节点业务逻辑 - + Args: state: 工作流状态 - + Returns: 最终输出字符串 """ logger.info(f"节点 {self.node_id} (End) 开始执行") - + # 获取配置的输出模板 output_template = self.config.get("output") @@ -39,11 +40,11 @@ class EndNode(BaseNode): output = self._render_template(output_template, state) else: output = "工作流已完成" - + # 统计信息(用于日志) node_outputs = state.get("node_outputs", {}) total_nodes = len(node_outputs) - + logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") return output @@ -127,24 +128,26 @@ class EndNode(BaseNode): return parts async def execute_stream(self, state: WorkflowState): - """流式执行 end 节点业务逻辑 + """Execute End node business logic (streaming) - 智能输出策略: - 1. 检测模板中是否引用了直接上游节点 - 2. 如果引用了,只输出该引用**之后**的部分(后缀) - 3. 前缀和引用内容已经在上游节点流式输出时发送了 + Smart output strategy: + 1. Check if template references a direct upstream LLM node + 2. If yes, only output the part AFTER that reference (suffix) + 3. Prefix and LLM content have already been sent during LLM node streaming - 示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' - - 直接上游节点是 llm_qa - - 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 - - LLM 内容在 LLM 节点流式输出 - - End 节点只输出 ' lalalalala a'(后缀,一次性输出) + Note: Only LLM nodes get this special treatment. Other node types output normally. + + Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' + - Direct upstream LLM node is llm_qa + - Prefix '{{start.test}}hahaha ' was sent before LLM node streaming + - LLM content was streamed during LLM node execution + - End node only outputs ' lalalalala a' (suffix, sent as one chunk) Args: - state: 工作流状态 + state: Workflow state Yields: - 完成标记 + Completion marker """ logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") @@ -156,39 +159,45 @@ class EndNode(BaseNode): yield {"__final__": True, "result": output} return - # 找到直接上游节点 - direct_upstream_nodes = [] + # Find direct upstream LLM nodes + direct_upstream_llm_nodes = [] for edge in self.workflow_config.get("edges", []): if edge.get("target") == self.node_id: source_node_id = edge.get("source") - direct_upstream_nodes.append(source_node_id) + # Check if the source node is an LLM node + for node in self.workflow_config.get("nodes", []): + print("="*50) + logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}") + if node.get("id") == source_node_id and node.get("type") == NodeType.LLM: + direct_upstream_llm_nodes.append(source_node_id) + break - logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") + logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}") - # 解析模板部分 + # Parse template parts parts = self._parse_template_parts(output_template, state) logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") for i, part in enumerate(parts): logger.info(f"[模板解析] part[{i}]: {part}") - # 找到第一个引用直接上游节点的动态引用 - upstream_ref_index = None + # Find the first reference to a direct upstream LLM node + upstream_llm_ref_index = None for i, part in enumerate(parts): - if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes: - upstream_ref_index = i - logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") + if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes: + upstream_llm_ref_index = i + logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}") break - if upstream_ref_index is None: - # 没有引用直接上游节点,输出完整模板内容 + if upstream_llm_ref_index is None: + # No reference to direct upstream LLM node, output complete template content output = self._render_template(output_template, state) - logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'") + logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'") - # 通过 writer 发送完整内容(作为一个 message chunk) + # Send complete content via writer (as a single message chunk) from langgraph.config import get_stream_writer writer = get_stream_writer() writer({ - "type": "message", # End 节点的输出使用 message 类型 + "type": "message", # End node output uses message type "node_id": self.node_id, "chunk": output, "full_content": output, @@ -197,17 +206,17 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") - # yield 完成标记 + # yield completion marker yield {"__final__": True, "result": output} return - # 有引用直接上游节点,只输出该引用之后的部分(后缀) - logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") + # Has reference to direct upstream LLM node, only output the part after that reference (suffix) + logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)") - # 收集后缀部分 + # Collect suffix parts suffix_parts = [] - logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}") - for i in range(upstream_ref_index + 1, len(parts)): + logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1} 到 {len(parts) - 1}") + for i in range(upstream_llm_ref_index + 1, len(parts)): part = parts[i] logger.info(f"[后缀调试] 处理 part[{i}]: {part}") if part["type"] == "static": @@ -219,7 +228,7 @@ class EndNode(BaseNode): # Other dynamic references (if there are multiple references) node_id = part["node_id"] field = part["field"] - + # Use VariablePool to get variable value pool = self.get_variable_pool(state) try: @@ -232,7 +241,7 @@ class EndNode(BaseNode): # Convert to string if not None suffix_parts.append(str(content) if content is not None else "") - + # 拼接后缀 suffix = "".join(suffix_parts) @@ -261,8 +270,8 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}") else: - logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}") - + logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}") + # 统计信息 node_outputs = state.get("node_outputs", {}) total_nodes = len(node_outputs) From 3afe5475592b6447811f64dc9aaacfd6be2b4398 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=96=B0=E6=9C=88?= Date: Wed, 24 Dec 2025 06:59:22 +0000 Subject: [PATCH 13/13] Merge #37 into develop from fix/memory_reflection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 反思输出输入格式统一 * fix/memory_reflection: (60 commits squashed) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py # api/app/schemas/memory_reflection_schemas.py - 反思优化 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 统一输出 - 统一输出 - 统一输出 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py - 统一输出 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 统一输出 - 反思速度提升,从4分钟优化成1分10-40秒 - 反思速度提升,从4分钟优化成1分10-40秒 - 反思速度提升,从4分钟优化成1分10-40秒 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 反思速度提升,从4分钟优化成1分10-40秒 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 更新 self_reflexion.py - 反思图谱添加边的修改 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 反思图谱添加边的修改 - 反思图谱添加边的修改 - 反思图谱添加边的修改 - 反思图谱添加边的修改 - 反思图谱添加边的修改 - update # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py # api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 - 反思BUG修复 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 反思BUG修复 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 反思BUG修复 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 反思输出输入格式统一 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/core/memory/utils/prompt/template_render.py - 反思优化提示词,提升速度,删除多余LOG日志 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection Signed-off-by: aliyun8644380055 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/37 --- .../reflection_engine/example/example.json | 24 ++++--------- .../reflection_engine/self_reflexion.py | 34 ++----------------- .../utils/prompt/prompts/evaluate.jinja2 | 2 ++ .../utils/prompt/prompts/reflexion.jinja2 | 1 - .../memory/utils/prompt/template_render.py | 29 ++++++++++++---- api/app/schemas/memory_storage_schema.py | 10 +++++- 6 files changed, 42 insertions(+), 58 deletions(-) diff --git a/api/app/core/memory/storage_services/reflection_engine/example/example.json b/api/app/core/memory/storage_services/reflection_engine/example/example.json index fe7a3816..18a2b185 100644 --- a/api/app/core/memory/storage_services/reflection_engine/example/example.json +++ b/api/app/core/memory/storage_services/reflection_engine/example/example.json @@ -50,9 +50,7 @@ "entity2_name": "用户", "entity2": { "description": "叙述者,讲述个人工作与生活经历的个体", - "statement_id": "62beac695b1346f4871740a45db88782", - "name": "用户", - "id": "3d3896797b334572a80d57590026063d" + "name": "用户" } }, { @@ -62,9 +60,7 @@ "entity2_name": "身份信息", "entity2": { "description": "用于个人身份识别的数据", - "statement_id": "030afd362e9b4110b139e68e5d3e7143", - "name": "身份信息", - "id": "aa766a517e82490599a9b3af54cfd933" + "name": "身份信息" } }, { @@ -74,9 +70,7 @@ "entity2_name": "6222023847595898", "entity2": { "description": "用户的银行卡号码", - "statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f", - "name": "6222023847595898", - "id": "610ba361918f4e68a65ce6ad06e5c7a0" + "name": "6222023847595898" } }, { @@ -88,9 +82,7 @@ "entity_idx": 1, "aliases": ["上海办"], "description": "位于上海的工作办公场所", - "statement_id": "8b1b12e23b844b8088dfeb67da6ad669", - "name": "上海办公室", - "id": "fb702ef695c14e14af3e56786bc8815b" + "name": "上海办公室" } }, { @@ -101,9 +93,7 @@ "entity2": { "aliases": ["京", "京城", "北平"], "description": "中国的首都城市,用户主要工作和生活所在地", - "statement_id": "62beac695b1346f4871740a45db88782", - "name": "北京", - "id": "81b2d1a571bb46a08a2d7a1e87efb945" + "name": "北京" } }, { @@ -113,9 +103,7 @@ "entity2_name": "身份证号", "entity2": { "description": "中华人民共和国公民的身份号码", - "statement_id": "030afd362e9b4110b139e68e5d3e7143", - "name": "身份证号", - "id": "3e5f920645b2404fadb0e9ff60d1306e" + "name": "身份证号" } } ] diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index 224a9560..3a4db30d 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -239,8 +239,6 @@ class ReflectionEngine: # # 检查是否真的有冲突 conflicts_found='' - # 记录冲突数据 - await self._log_data("conflict", conflict_data) conflicts_found='' # 3. 解决冲突 solved_data = await self._resolve_conflicts(conflict_data, statement_databasets) @@ -258,8 +256,6 @@ class ReflectionEngine: conflicts_resolved = len(solved_data) logging.info(f"解决了 {conflicts_resolved} 个冲突") - # 记录解决方案 - await self._log_data("solved_data", solved_data) # 4. 应用反思结果(更新记忆库) memories_updated=await self._apply_reflection_results(solved_data) @@ -360,14 +356,7 @@ class ReflectionEngine: memory_verifies.append(item['memory_verify']) result_data['memory_verifies'] = memory_verifies result_data['quality_assessments'] = quality_assessments - - # 检查是否真的有冲突 - has_conflict = conflict_data[0].get('conflict', False) - conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0 - logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突") - - # 记录冲突数据 - await self._log_data("conflict", conflict_data) + conflicts_found='' # Clearn conflict_data,And memory_verify和quality_assessment cleaned_conflict_data = [] @@ -377,6 +366,7 @@ class ReflectionEngine: 'conflict': item['conflict'] } cleaned_conflict_data.append(cleaned_item) + # 3. 解决冲突 solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data) if not solved_data: @@ -615,26 +605,7 @@ class ReflectionEngine: success_count = await neo4j_data(changes) return success_count - async def _log_data(self, label: str, data: Any) -> None: - """ - 记录数据到文件 - Args: - label: 数据标签 - data: 要记录的数据 - """ - - def _write(): - try: - with open("reflexion_data.json", "a", encoding="utf-8") as f: - f.write(f"### {label} ###\n") - json.dump(data, f, ensure_ascii=False, indent=4) - f.write("\n\n") - except Exception as e: - logging.warning(f"记录数据失败: {e}") - - # 在后台线程中执行写入,避免阻塞事件循环 - await asyncio.to_thread(_write) # 基于时间的反思方法 async def time_based_reflection( @@ -723,4 +694,3 @@ class ReflectionEngine: raise ValueError(f"未知的反思基线: {self.config.baseline}") - diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index b292c804..200f2667 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -17,10 +17,12 @@ - **日期属性冲突**: 同一人的生日等单值属性出现多值 - **先后约束违反**: 存在A→B约束但t(A)>t(B)(如入学>毕业) - **互斥重叠**: 同一时间出现在不同地点等互斥事件 +- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候 ### 事实冲突 - **属性互斥**: 同一实体的相反属性(喜欢↔不喜欢) - **关系矛盾**: 同一实体在相同语境下的不同关系描述 - **身份冲突**: 同一实体被赋予不同类型或角色 +- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候 ### 混合冲突 检测所有逻辑不一致或相互矛盾的记录。 **检测原则**: diff --git a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index 36474d91..99476c82 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -171,7 +171,6 @@ ] } ``` - **输出要求**: - 只输出JSON,不添加解释文本 - 使用标准双引号,必要时转义 diff --git a/api/app/core/memory/utils/prompt/template_render.py b/api/app/core/memory/utils/prompt/template_render.py index 46bb64e8..68e0ffe4 100644 --- a/api/app/core/memory/utils/prompt/template_render.py +++ b/api/app/core/memory/utils/prompt/template_render.py @@ -7,7 +7,7 @@ from typing import List, Dict, Any prompt_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) -async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any], +async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any, baseline: str = "TIME", memory_verify: bool = False,quality_assessment:bool = False, statement_databasets: List[str] = [],language_type:str = "zh") -> str: @@ -16,7 +16,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any Args: evaluate_data: The data to evaluate - schema: The JSON schema to use for the output. + schema: The Pydantic model class or JSON schema to use for the output. baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT) memory_verify: Whether to enable memory verification for privacy detection @@ -25,9 +25,17 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any """ template = prompt_env.get_template("evaluate.jinja2") + # Convert Pydantic model to JSON schema if needed + if hasattr(schema, 'model_json_schema'): + json_schema = schema.model_json_schema() + elif hasattr(schema, 'schema'): + json_schema = schema.schema() + else: + json_schema = schema + rendered_prompt = template.render( evaluate_data=evaluate_data, - json_schema=schema, + json_schema=json_schema, baseline=baseline, memory_verify=memory_verify, quality_assessment=quality_assessment, @@ -36,14 +44,15 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any ) return rendered_prompt -async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False, + +async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False, statement_databasets: List[str] = [],language_type:str = "zh") -> str: """ Renders the reflexion prompt using the reflexion_optimized.jinja2 template. Args: data: The data to reflex on. - schema: The JSON schema to use for the output. + schema: The Pydantic model class or JSON schema to use for the output. baseline: The baseline type for conflict resolution. Returns: @@ -51,7 +60,15 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], """ template = prompt_env.get_template("reflexion.jinja2") - rendered_prompt = template.render(data=data, json_schema=schema, + # Convert Pydantic model to JSON schema if needed + if hasattr(schema, 'model_json_schema'): + json_schema = schema.model_json_schema() + elif hasattr(schema, 'schema'): + json_schema = schema.schema() + else: + json_schema = schema + + rendered_prompt = template.render(data=data, json_schema=json_schema, baseline=baseline,memory_verify=memory_verify, statement_databasets=statement_databasets,language_type=language_type) diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index df70ec77..33d0d097 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -31,7 +31,7 @@ class BaseDataSchema(BaseModel): # 保持原有必需字段为可选,以兼容不同数据源 id: Optional[str] = Field(None, description="The unique identifier for the data entry.") statement: Optional[str] = Field(None, description="The statement text.") - created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.") + created_at: Optional[str] = Field(None, description="The creation timestamp in ISO 8601 format.") expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.") description: Optional[str] = Field(None, description="The description of the data entry.") @@ -46,6 +46,14 @@ class BaseDataSchema(BaseModel): relationship: Optional[Union[str, Dict[str, Any]]] = Field(None, description="The relationship object or string.") entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.") + @model_validator(mode="before") + def _set_default_created_at(cls, v): + """Set default created_at if missing""" + if isinstance(v, dict) and v.get("created_at") is None: + from datetime import datetime + v["created_at"] = datetime.now().isoformat() + return v + class QualityAssessmentSchema(BaseModel): """Schema for memory quality assessment results."""