diff --git a/api/app/core/memory/agent/mcp_server/server.py b/api/app/core/memory/agent/mcp_server/server.py index f7a23236..26f24824 100644 --- a/api/app/core/memory/agent/mcp_server/server.py +++ b/api/app/core/memory/agent/mcp_server/server.py @@ -131,6 +131,18 @@ def main(): # Get MCP port from environment (default: 8081) mcp_port = int(os.getenv("MCP_PORT", "8081")) + logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") + + # Configure DNS rebinding protection for Docker container compatibility + from mcp.server.fastmcp.server import TransportSecuritySettings + + # Disable DNS rebinding protection to allow Docker container hostnames + # This allows containers to connect using service names like 'mcp-server' + mcp.settings.transport_security = TransportSecuritySettings( + enable_dns_rebinding_protection=False, + ) + logger.info("DNS rebinding protection: disabled for Docker container compatibility") + # logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") # Run the server with SSE transport for HTTP connections diff --git a/api/app/core/memory/storage_services/reflection_engine/example/example.json b/api/app/core/memory/storage_services/reflection_engine/example/example.json index fe7a3816..18a2b185 100644 --- a/api/app/core/memory/storage_services/reflection_engine/example/example.json +++ b/api/app/core/memory/storage_services/reflection_engine/example/example.json @@ -50,9 +50,7 @@ "entity2_name": "用户", "entity2": { "description": "叙述者,讲述个人工作与生活经历的个体", - "statement_id": "62beac695b1346f4871740a45db88782", - "name": "用户", - "id": "3d3896797b334572a80d57590026063d" + "name": "用户" } }, { @@ -62,9 +60,7 @@ "entity2_name": "身份信息", "entity2": { "description": "用于个人身份识别的数据", - "statement_id": "030afd362e9b4110b139e68e5d3e7143", - "name": "身份信息", - "id": "aa766a517e82490599a9b3af54cfd933" + "name": "身份信息" } }, { @@ -74,9 +70,7 @@ "entity2_name": "6222023847595898", "entity2": { "description": "用户的银行卡号码", - "statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f", - "name": "6222023847595898", - "id": "610ba361918f4e68a65ce6ad06e5c7a0" + "name": "6222023847595898" } }, { @@ -88,9 +82,7 @@ "entity_idx": 1, "aliases": ["上海办"], "description": "位于上海的工作办公场所", - "statement_id": "8b1b12e23b844b8088dfeb67da6ad669", - "name": "上海办公室", - "id": "fb702ef695c14e14af3e56786bc8815b" + "name": "上海办公室" } }, { @@ -101,9 +93,7 @@ "entity2": { "aliases": ["京", "京城", "北平"], "description": "中国的首都城市,用户主要工作和生活所在地", - "statement_id": "62beac695b1346f4871740a45db88782", - "name": "北京", - "id": "81b2d1a571bb46a08a2d7a1e87efb945" + "name": "北京" } }, { @@ -113,9 +103,7 @@ "entity2_name": "身份证号", "entity2": { "description": "中华人民共和国公民的身份号码", - "statement_id": "030afd362e9b4110b139e68e5d3e7143", - "name": "身份证号", - "id": "3e5f920645b2404fadb0e9ff60d1306e" + "name": "身份证号" } } ] diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index ebbb97f7..97f51fb9 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -269,8 +269,6 @@ class ReflectionEngine: # # 检查是否真的有冲突 conflicts_found='' - # 记录冲突数据 - await self._log_data("conflict", conflict_data) conflicts_found='' # 3. 解决冲突 solved_data = await self._resolve_conflicts(conflict_data, statement_databasets) @@ -288,8 +286,6 @@ class ReflectionEngine: conflicts_resolved = len(solved_data) logging.info(f"解决了 {conflicts_resolved} 个冲突") - # 记录解决方案 - await self._log_data("solved_data", solved_data) # 4. 应用反思结果(更新记忆库) memories_updated=await self._apply_reflection_results(solved_data) @@ -390,14 +386,7 @@ class ReflectionEngine: memory_verifies.append(item['memory_verify']) result_data['memory_verifies'] = memory_verifies result_data['quality_assessments'] = quality_assessments - - # 检查是否真的有冲突 - has_conflict = conflict_data[0].get('conflict', False) - conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0 - logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突") - - # 记录冲突数据 - await self._log_data("conflict", conflict_data) + conflicts_found='' # Clearn conflict_data,And memory_verify和quality_assessment cleaned_conflict_data = [] @@ -407,6 +396,7 @@ class ReflectionEngine: 'conflict': item['conflict'] } cleaned_conflict_data.append(cleaned_item) + # 3. 解决冲突 solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data) if not solved_data: @@ -645,26 +635,7 @@ class ReflectionEngine: success_count = await neo4j_data(changes) return success_count - async def _log_data(self, label: str, data: Any) -> None: - """ - 记录数据到文件 - Args: - label: 数据标签 - data: 要记录的数据 - """ - - def _write(): - try: - with open("reflexion_data.json", "a", encoding="utf-8") as f: - f.write(f"### {label} ###\n") - json.dump(data, f, ensure_ascii=False, indent=4) - f.write("\n\n") - except Exception as e: - logging.warning(f"记录数据失败: {e}") - - # 在后台线程中执行写入,避免阻塞事件循环 - await asyncio.to_thread(_write) # 基于时间的反思方法 async def time_based_reflection( @@ -753,4 +724,3 @@ class ReflectionEngine: raise ValueError(f"未知的反思基线: {self.config.baseline}") - diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index b292c804..200f2667 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -17,10 +17,12 @@ - **日期属性冲突**: 同一人的生日等单值属性出现多值 - **先后约束违反**: 存在A→B约束但t(A)>t(B)(如入学>毕业) - **互斥重叠**: 同一时间出现在不同地点等互斥事件 +- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候 ### 事实冲突 - **属性互斥**: 同一实体的相反属性(喜欢↔不喜欢) - **关系矛盾**: 同一实体在相同语境下的不同关系描述 - **身份冲突**: 同一实体被赋予不同类型或角色 +- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候 ### 混合冲突 检测所有逻辑不一致或相互矛盾的记录。 **检测原则**: diff --git a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index 36474d91..99476c82 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -171,7 +171,6 @@ ] } ``` - **输出要求**: - 只输出JSON,不添加解释文本 - 使用标准双引号,必要时转义 diff --git a/api/app/core/memory/utils/prompt/template_render.py b/api/app/core/memory/utils/prompt/template_render.py index 46bb64e8..68e0ffe4 100644 --- a/api/app/core/memory/utils/prompt/template_render.py +++ b/api/app/core/memory/utils/prompt/template_render.py @@ -7,7 +7,7 @@ from typing import List, Dict, Any prompt_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) -async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any], +async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any, baseline: str = "TIME", memory_verify: bool = False,quality_assessment:bool = False, statement_databasets: List[str] = [],language_type:str = "zh") -> str: @@ -16,7 +16,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any Args: evaluate_data: The data to evaluate - schema: The JSON schema to use for the output. + schema: The Pydantic model class or JSON schema to use for the output. baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT) memory_verify: Whether to enable memory verification for privacy detection @@ -25,9 +25,17 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any """ template = prompt_env.get_template("evaluate.jinja2") + # Convert Pydantic model to JSON schema if needed + if hasattr(schema, 'model_json_schema'): + json_schema = schema.model_json_schema() + elif hasattr(schema, 'schema'): + json_schema = schema.schema() + else: + json_schema = schema + rendered_prompt = template.render( evaluate_data=evaluate_data, - json_schema=schema, + json_schema=json_schema, baseline=baseline, memory_verify=memory_verify, quality_assessment=quality_assessment, @@ -36,14 +44,15 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any ) return rendered_prompt -async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False, + +async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False, statement_databasets: List[str] = [],language_type:str = "zh") -> str: """ Renders the reflexion prompt using the reflexion_optimized.jinja2 template. Args: data: The data to reflex on. - schema: The JSON schema to use for the output. + schema: The Pydantic model class or JSON schema to use for the output. baseline: The baseline type for conflict resolution. Returns: @@ -51,7 +60,15 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], """ template = prompt_env.get_template("reflexion.jinja2") - rendered_prompt = template.render(data=data, json_schema=schema, + # Convert Pydantic model to JSON schema if needed + if hasattr(schema, 'model_json_schema'): + json_schema = schema.model_json_schema() + elif hasattr(schema, 'schema'): + json_schema = schema.schema() + else: + json_schema = schema + + rendered_prompt = template.render(data=data, json_schema=json_schema, baseline=baseline,memory_verify=memory_verify, statement_databasets=statement_databasets,language_type=language_type) diff --git a/api/app/core/rag/deepdoc/parser/mineru_parser.py b/api/app/core/rag/deepdoc/parser/mineru_parser.py index ec380922..fe6178ec 100644 --- a/api/app/core/rag/deepdoc/parser/mineru_parser.py +++ b/api/app/core/rag/deepdoc/parser/mineru_parser.py @@ -13,7 +13,7 @@ from io import BytesIO from os import PathLike from pathlib import Path from queue import Empty, Queue -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import numpy as np import pdfplumber @@ -439,7 +439,7 @@ class MinerUParser(RAGPdfParser): def parse_pdf( self, filepath: str | PathLike[str], - binary: BytesIO | bytes, + binary: Optional[Union[BytesIO, bytes]] = None, callback: Optional[Callable] = None, *, output_dir: Optional[str] = None, diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index 176f996a..198d1473 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -740,8 +740,9 @@ class ElasticSearchVector(BaseVector): self._client.indices.create(index=self._collection_name, body=index_mapping) -class ElasticSearchVectorFactory(ABC): - def init_vector(self, knowledge: Knowledge) -> ElasticSearchVector: +class ElasticSearchVectorFactory: + @staticmethod + def init_vector(knowledge: Knowledge) -> ElasticSearchVector: collection_name = f"Vector_index_{knowledge.id}_Node" # Use regular Elasticsearch with config values @@ -763,17 +764,17 @@ class ElasticSearchVectorFactory(ABC): } ) - if knowledge.embedding and knowledge.reranker: - return ElasticSearchVector( - index_name=collection_name, - config=ElasticSearchConfig(**config_dict), - embedding_config=knowledge.embedding.api_keys[0], - reranker_config=knowledge.reranker.api_keys[0] - ) - else: - if knowledge.embedding is None: - raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}") - if knowledge.reranker is None: - raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}") + if knowledge.embedding is None: + raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}") + if knowledge.reranker is None: + raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}") + + return ElasticSearchVector( + index_name=collection_name, + config=ElasticSearchConfig(**config_dict), + embedding_config=knowledge.embedding.api_keys[0], + reranker_config=knowledge.reranker.api_keys[0] + ) + diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index 1d00532e..fa5a5a2b 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -9,7 +9,7 @@ from app.core.workflow.nodes.assigner import AssignerNode from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.if_else import IfElseNode -# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.start import StartNode @@ -26,6 +26,6 @@ __all__ = [ "EndNode", "NodeFactory", "WorkflowNode", - # "KnowledgeRetrievalNode", + "KnowledgeRetrievalNode", "AssignerNode", ] diff --git a/api/app/core/workflow/nodes/assigner/config.py b/api/app/core/workflow/nodes/assigner/config.py index 1cb0def3..03302af4 100644 --- a/api/app/core/workflow/nodes/assigner/config.py +++ b/api/app/core/workflow/nodes/assigner/config.py @@ -1,21 +1,32 @@ -from pydantic import Field +from pydantic import Field, BaseModel from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.enums import AssignmentOperator -class AssignerNodeConfig(BaseNodeConfig): +class AssignmentItem(BaseModel): + """ + Single assignment definition. + """ + variable_selector: str | list[str] = Field( ..., - description="Variables to be assigned", + description="Target variable name(s) to assign", ) operation: AssignmentOperator = Field( ..., - description="Operator to assign", + description="Assignment operator", ) value: str | list[str] = Field( ..., - description="Values to assign", + description="Value(s) to assign to the variable(s)", + ) + + +class AssignerNodeConfig(BaseNodeConfig): + assignments: list[AssignmentItem] = Field( + ..., + description="List of variable assignment definitions", ) diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index eb32bf8b..b8b7c1f4 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -29,52 +29,52 @@ class AssignerNode(BaseNode): """ # Initialize a variable pool for accessing conversation, node, and system variables pool = VariablePool(state) + for assignment in self.typed_config.assignments: + # Get the target variable selector (e.g., "conv.test") + variable_selector = assignment.variable_selector + if isinstance(variable_selector, str): + # Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"] + variable_selector = variable_selector.split('.') - # Get the target variable selector (e.g., "conv.test") - variable_selector = self.typed_config.variable_selector - if isinstance(variable_selector, str): - # Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"] - variable_selector = variable_selector.split('.') + # Only conversation variables ('conv') are allowed + if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature) + raise ValueError("Only conversation variables can be assigned.") - # Only conversation variables ('conv') are allowed - if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature) - raise ValueError("Only conversation variables can be assigned.") + # Get the value or expression to assign + value = assignment.value + if isinstance(value, list): + value = '.'.join(value) + value = ExpressionEvaluator.evaluate( + expression=value, + variables=pool.get_all_conversation_vars(), + node_outputs=pool.get_all_node_outputs(), + system_vars=pool.get_all_system_vars(), + ) - # Get the value or expression to assign - value = self.typed_config.value - if isinstance(value, list): - value = '.'.join(value) - value = ExpressionEvaluator.evaluate( - expression=value, - variables=pool.get_all_conversation_vars(), - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars(), - ) + # Select the appropriate assignment operator instance based on the target variable type + operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))( + pool, variable_selector, value + ) - # Select the appropriate assignment operator instance based on the target variable type - operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))( - pool, variable_selector, value - ) - - # Execute the configured assignment operation - match self.typed_config.operation: - case AssignmentOperator.ASSIGN: - operator.assign() - case AssignmentOperator.CLEAR: - operator.clear() - case AssignmentOperator.ADD: - operator.add() - case AssignmentOperator.SUBTRACT: - operator.subtract() - case AssignmentOperator.MULTIPLY: - operator.multiply() - case AssignmentOperator.DIVIDE: - operator.divide() - case AssignmentOperator.APPEND: - operator.append() - case AssignmentOperator.REMOVE_FIRST: - operator.remove_first() - case AssignmentOperator.REMOVE_LAST: - operator.remove_last() - case _: - raise ValueError(f"Invalid Operator: {self.typed_config.operation}") + # Execute the configured assignment operation + match assignment.operation: + case AssignmentOperator.ASSIGN: + operator.assign() + case AssignmentOperator.CLEAR: + operator.clear() + case AssignmentOperator.ADD: + operator.add() + case AssignmentOperator.SUBTRACT: + operator.subtract() + case AssignmentOperator.MULTIPLY: + operator.multiply() + case AssignmentOperator.DIVIDE: + operator.divide() + case AssignmentOperator.APPEND: + operator.append() + case AssignmentOperator.REMOVE_FIRST: + operator.remove_first() + case AssignmentOperator.REMOVE_LAST: + operator.remove_last() + case _: + raise ValueError(f"Invalid Operator: {assignment.operation}") diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index ecded070..e9f102f0 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -14,7 +14,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig from app.core.workflow.nodes.agent.config import AgentNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.if_else.config import IfElseNodeConfig -# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.assigner.config import AssignerNodeConfig __all__ = [ @@ -30,6 +30,6 @@ __all__ = [ "AgentNodeConfig", "TransformNodeConfig", "IfElseNodeConfig", - # "KnowledgeRetrievalNodeConfig", + "KnowledgeRetrievalNodeConfig", "AssignerNodeConfig", ] diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 3cece96b..efc62dc5 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -9,28 +9,29 @@ import re import asyncio from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.enums import NodeType logger = logging.getLogger(__name__) class EndNode(BaseNode): """End 节点 - + 工作流的结束节点,根据配置的模板输出最终结果。 支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。 """ - + async def execute(self, state: WorkflowState) -> str: """执行 end 节点业务逻辑 - + Args: state: 工作流状态 - + Returns: 最终输出字符串 """ logger.info(f"节点 {self.node_id} (End) 开始执行") - + # 获取配置的输出模板 output_template = self.config.get("output") @@ -39,11 +40,11 @@ class EndNode(BaseNode): output = self._render_template(output_template, state) else: output = "工作流已完成" - + # 统计信息(用于日志) node_outputs = state.get("node_outputs", {}) total_nodes = len(node_outputs) - + logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") return output @@ -127,24 +128,26 @@ class EndNode(BaseNode): return parts async def execute_stream(self, state: WorkflowState): - """流式执行 end 节点业务逻辑 + """Execute End node business logic (streaming) - 智能输出策略: - 1. 检测模板中是否引用了直接上游节点 - 2. 如果引用了,只输出该引用**之后**的部分(后缀) - 3. 前缀和引用内容已经在上游节点流式输出时发送了 + Smart output strategy: + 1. Check if template references a direct upstream LLM node + 2. If yes, only output the part AFTER that reference (suffix) + 3. Prefix and LLM content have already been sent during LLM node streaming - 示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' - - 直接上游节点是 llm_qa - - 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 - - LLM 内容在 LLM 节点流式输出 - - End 节点只输出 ' lalalalala a'(后缀,一次性输出) + Note: Only LLM nodes get this special treatment. Other node types output normally. + + Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' + - Direct upstream LLM node is llm_qa + - Prefix '{{start.test}}hahaha ' was sent before LLM node streaming + - LLM content was streamed during LLM node execution + - End node only outputs ' lalalalala a' (suffix, sent as one chunk) Args: - state: 工作流状态 + state: Workflow state Yields: - 完成标记 + Completion marker """ logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") @@ -156,39 +159,45 @@ class EndNode(BaseNode): yield {"__final__": True, "result": output} return - # 找到直接上游节点 - direct_upstream_nodes = [] + # Find direct upstream LLM nodes + direct_upstream_llm_nodes = [] for edge in self.workflow_config.get("edges", []): if edge.get("target") == self.node_id: source_node_id = edge.get("source") - direct_upstream_nodes.append(source_node_id) + # Check if the source node is an LLM node + for node in self.workflow_config.get("nodes", []): + print("="*50) + logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}") + if node.get("id") == source_node_id and node.get("type") == NodeType.LLM: + direct_upstream_llm_nodes.append(source_node_id) + break - logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") + logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}") - # 解析模板部分 + # Parse template parts parts = self._parse_template_parts(output_template, state) logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") for i, part in enumerate(parts): logger.info(f"[模板解析] part[{i}]: {part}") - # 找到第一个引用直接上游节点的动态引用 - upstream_ref_index = None + # Find the first reference to a direct upstream LLM node + upstream_llm_ref_index = None for i, part in enumerate(parts): - if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes: - upstream_ref_index = i - logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") + if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes: + upstream_llm_ref_index = i + logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}") break - if upstream_ref_index is None: - # 没有引用直接上游节点,输出完整模板内容 + if upstream_llm_ref_index is None: + # No reference to direct upstream LLM node, output complete template content output = self._render_template(output_template, state) - logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'") + logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'") - # 通过 writer 发送完整内容(作为一个 message chunk) + # Send complete content via writer (as a single message chunk) from langgraph.config import get_stream_writer writer = get_stream_writer() writer({ - "type": "message", # End 节点的输出使用 message 类型 + "type": "message", # End node output uses message type "node_id": self.node_id, "chunk": output, "full_content": output, @@ -197,17 +206,17 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") - # yield 完成标记 + # yield completion marker yield {"__final__": True, "result": output} return - # 有引用直接上游节点,只输出该引用之后的部分(后缀) - logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") + # Has reference to direct upstream LLM node, only output the part after that reference (suffix) + logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)") - # 收集后缀部分 + # Collect suffix parts suffix_parts = [] - logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}") - for i in range(upstream_ref_index + 1, len(parts)): + logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1} 到 {len(parts) - 1}") + for i in range(upstream_llm_ref_index + 1, len(parts)): part = parts[i] logger.info(f"[后缀调试] 处理 part[{i}]: {part}") if part["type"] == "static": @@ -219,7 +228,7 @@ class EndNode(BaseNode): # Other dynamic references (if there are multiple references) node_id = part["node_id"] field = part["field"] - + # Use VariablePool to get variable value pool = self.get_variable_pool(state) try: @@ -232,7 +241,7 @@ class EndNode(BaseNode): # Convert to string if not None suffix_parts.append(str(content) if content is not None else "") - + # 拼接后缀 suffix = "".join(suffix_parts) @@ -261,8 +270,8 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}") else: - logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}") - + logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}") + # 统计信息 node_outputs = state.get("node_outputs", {}) total_nodes = len(node_outputs) diff --git a/api/app/core/workflow/nodes/knowledge/__init__.py b/api/app/core/workflow/nodes/knowledge/__init__.py new file mode 100644 index 00000000..25d0f00b --- /dev/null +++ b/api/app/core/workflow/nodes/knowledge/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.knowledge.node import KnowledgeRetrievalNode + +__all__ = ["KnowledgeRetrievalNode", "KnowledgeRetrievalNodeConfig"] diff --git a/api/app/core/workflow/nodes/knowledge/config.py b/api/app/core/workflow/nodes/knowledge/config.py new file mode 100644 index 00000000..530116ff --- /dev/null +++ b/api/app/core/workflow/nodes/knowledge/config.py @@ -0,0 +1,38 @@ +from uuid import UUID + +from pydantic import Field + +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.schemas.chunk_schema import RetrieveType + + +class KnowledgeRetrievalNodeConfig(BaseNodeConfig): + query: str = Field( + ..., + description="Search query string" + ) + + kb_ids: list[UUID] = Field( + ..., + description="Knowledge base IDs" + ) + + similarity_threshold: float = Field( + default=0.2, + description="Knowledge base similarity threshold" + ) + + vector_similarity_weight: float = Field( + default=0.3, + description="Knowledge base vector similarity weight" + ) + + top_k: int = Field( + default=4, + description="Knowledge base top k" + ) + + retrieve_type: RetrieveType = Field( + default=RetrieveType.PARTICIPLE, + description="Retrieve type" + ) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py new file mode 100644 index 00000000..97ebaa82 --- /dev/null +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -0,0 +1,93 @@ +import logging +from typing import Any + +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig +from app.db import get_db_context +from app.models import knowledge_model, knowledgeshare_model +from app.repositories import knowledge_repository +from app.schemas.chunk_schema import RetrieveType +from app.services import knowledge_service, knowledgeshare_service + +logger = logging.getLogger(__name__) + + +class KnowledgeRetrievalNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) + + async def execute(self, state: WorkflowState) -> Any: + query = self._render_template(self.typed_config.query, state) + with get_db_context() as db: + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + existing_ids = knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + if share_ids: + filters = [ + knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids) + ] + items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( + db=db, + filters=filters + ) + existing_ids.extend(items) + + if not existing_ids: + raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") + + kb_id = existing_ids[0] + uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] + indices = ",".join(uuid_strs) + + db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) + if not db_knowledge: + raise RuntimeError("The knowledge base does not exist or access is denied.") + + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + match self.typed_config.retrieve_type: + case RetrieveType.PARTICIPLE: + rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.similarity_threshold) + return [chunk.model_dump() for chunk in rs] + case RetrieveType.SEMANTIC: + rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + return [chunk.model_dump() for chunk in rs] + case _: + rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.similarity_threshold) + # Efficient deduplication + seen_ids = set() + unique_rs = [] + for doc in rs1 + rs2: + if doc.metadata["doc_id"] not in seen_ids: + seen_ids.add(doc.metadata["doc_id"]) + unique_rs.append(doc) + rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) + return [chunk.model_dump() for chunk in rs] diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 93364083..2ae31d4d 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -7,7 +7,7 @@ import logging from typing import Any, Union -# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.end import EndNode @@ -29,7 +29,7 @@ WorkflowNode = Union[ AgentNode, TransformNode, AssignerNode, - # KnowledgeRetrievalNode, + KnowledgeRetrievalNode, ] @@ -47,7 +47,7 @@ class NodeFactory: NodeType.AGENT: AgentNode, NodeType.TRANSFORM: TransformNode, NodeType.IF_ELSE: IfElseNode, - # NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.ASSIGNER: AssignerNode, } diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index 0acf04b0..7c3a2fca 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -23,7 +23,7 @@ class StartNode(BaseNode): 注意:变量的验证和默认值处理由 Executor 在初始化时完成。 """ - + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): """初始化 Start 节点 @@ -32,10 +32,10 @@ class StartNode(BaseNode): workflow_config: 工作流配置 """ super().__init__(node_config, workflow_config) - + # 解析并验证配置 self.typed_config = StartNodeConfig(**self.config) - + async def execute(self, state: WorkflowState) -> dict[str, Any]: """执行 start 节点业务逻辑 @@ -48,13 +48,13 @@ class StartNode(BaseNode): 包含系统参数、会话变量和自定义变量的字典 """ logger.info(f"节点 {self.node_id} (Start) 开始执行") - + # 创建变量池实例(在方法内复用) pool = self.get_variable_pool(state) - + # 处理自定义变量(传入 pool 避免重复创建) custom_vars = self._process_custom_variables(pool) - + # 返回业务数据(包含自定义变量) result = { "message": pool.get("sys.message"), @@ -64,14 +64,14 @@ class StartNode(BaseNode): "user_id": pool.get("sys.user_id"), **custom_vars # 自定义变量作为节点输出的一部分 } - + logger.info( f"节点 {self.node_id} (Start) 执行完成," f"输出了 {len(custom_vars)} 个自定义变量" ) - + return result - + def _process_custom_variables(self, pool) -> dict[str, Any]: """处理自定义变量 @@ -88,34 +88,33 @@ class StartNode(BaseNode): """ # 获取输入数据中的自定义变量 input_variables = pool.get("sys.input_variables", default={}) - + processed = {} - + # 遍历配置的变量定义 for var_def in self.typed_config.variables: var_name = var_def.name - + # 检查变量是否存在 if var_name in input_variables: # 使用用户提供的值 processed[var_name] = input_variables[var_name] - + elif var_def.required: # 必需变量缺失 raise ValueError( f"缺少必需的输入变量: {var_name}" + (f" ({var_def.description})" if var_def.description else "") ) - + elif var_def.default is not None: # 使用默认值 processed[var_name] = var_def.default logger.debug( f"变量 '{var_name}' 使用默认值: {var_def.default}" ) - + return processed - def _extract_input(self, state: WorkflowState) -> dict[str, Any]: """提取输入数据(用于记录) @@ -127,7 +126,7 @@ class StartNode(BaseNode): 输入数据字典 """ pool = self.get_variable_pool(state) - + return { "execution_id": pool.get("sys.execution_id"), "conversation_id": pool.get("sys.conversation_id"), diff --git a/api/app/repositories/knowledge_repository.py b/api/app/repositories/knowledge_repository.py index 5d4946fa..b7908cb0 100644 --- a/api/app/repositories/knowledge_repository.py +++ b/api/app/repositories/knowledge_repository.py @@ -52,7 +52,7 @@ def get_knowledges_paginated( raise -def get_chunded_knowledgeids( +def get_chunked_knowledgeids( db: Session, filters: list ) -> list: diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index df70ec77..33d0d097 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -31,7 +31,7 @@ class BaseDataSchema(BaseModel): # 保持原有必需字段为可选,以兼容不同数据源 id: Optional[str] = Field(None, description="The unique identifier for the data entry.") statement: Optional[str] = Field(None, description="The statement text.") - created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.") + created_at: Optional[str] = Field(None, description="The creation timestamp in ISO 8601 format.") expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.") description: Optional[str] = Field(None, description="The description of the data entry.") @@ -46,6 +46,14 @@ class BaseDataSchema(BaseModel): relationship: Optional[Union[str, Dict[str, Any]]] = Field(None, description="The relationship object or string.") entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.") + @model_validator(mode="before") + def _set_default_created_at(cls, v): + """Set default created_at if missing""" + if isinstance(v, dict) and v.get("created_at") is None: + from datetime import datetime + v["created_at"] = datetime.now().isoformat() + return v + class QualityAssessmentSchema(BaseModel): """Schema for memory quality assessment results.""" diff --git a/api/app/services/knowledge_service.py b/api/app/services/knowledge_service.py index b9d97c29..cf47fd4f 100644 --- a/api/app/services/knowledge_service.py +++ b/api/app/services/knowledge_service.py @@ -45,7 +45,7 @@ def get_chunded_knowledgeids( business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}") try: - items = knowledge_repository.get_chunded_knowledgeids( + items = knowledge_repository.get_chunked_knowledgeids( db=db, filters=filters ) diff --git a/api/app/templates/workflows/simple_qa/template.yml b/api/app/templates/workflows/simple_qa/template.yml index 1b68d55d..ab1af3c2 100644 --- a/api/app/templates/workflows/simple_qa/template.yml +++ b/api/app/templates/workflows/simple_qa/template.yml @@ -44,11 +44,11 @@ nodes: - role: user content: "{{ sys.message }}" - model_id: gpt-3.5-turbo + model_id: null temperature: 0.7 max_tokens: 1000 position: - x: 300 + x: 500 y: 100 - id: end @@ -57,7 +57,7 @@ nodes: config: output: "{{ llm_qa.output }}" position: - x: 500 + x: 900 y: 100 edges: diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 48ec137d..8470a5d1 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -3,7 +3,7 @@ version: '3.9' services: # MCP Server - standalone service mcp-server: - image: redbear-mem:latest + image: redbear-mem-open:latest container_name: mcp-server ports: - "8081:8081" # MCP server port @@ -28,14 +28,14 @@ services: # FastAPI application - connects to MCP server api: - image: redbear-mem:latest + image: redbear-mem-open:latest container_name: api ports: - "8002:8000" env_file: - .env environment: - - MCP_SERVER_URL=http://mcp-server:8081 + - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name - SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces volumes: - ./files:/files @@ -51,12 +51,12 @@ services: # Celery worker - connects to MCP server worker: - image: redbear-mem:latest + image: redbear-mem-open:latest container_name: worker env_file: - .env environment: - - MCP_SERVER_URL=http://mcp-server:8081 + - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro diff --git a/api/env.example b/api/env.example index 1354233d..8ceb3934 100644 --- a/api/env.example +++ b/api/env.example @@ -71,6 +71,18 @@ ENABLE_SINGLE_SESSION= MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024 FILE_PATH=/files +# RAG Setting +DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1 +HF_ENDPOINT=https://hf-mirror.com +MINERU_EXECUTABLE=mineru +MINERU_APISERVER=http://host.docker.internal:9987 +MINERU_OUTPUT_DIR=/files +MINERU_BACKEND=pipeline +MINERU_DELETE_OUTPUT=1 +TEXTLN_APISERVER=https://api.textin.com/ai/service/v1/pdf_to_markdown +TEXTLN_APP_ID= +TEXTLN_SECRET_CODE= + # VOLC ASR VOLC_APP_KEY= VOLC_ACCESS_KEY=