Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management

This commit is contained in:
Ke Sun
2025-12-24 16:07:27 +08:00
24 changed files with 365 additions and 202 deletions

View File

@@ -131,6 +131,18 @@ def main():
# Get MCP port from environment (default: 8081) # Get MCP port from environment (default: 8081)
mcp_port = int(os.getenv("MCP_PORT", "8081")) mcp_port = int(os.getenv("MCP_PORT", "8081"))
logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
# Configure DNS rebinding protection for Docker container compatibility
from mcp.server.fastmcp.server import TransportSecuritySettings
# Disable DNS rebinding protection to allow Docker container hostnames
# This allows containers to connect using service names like 'mcp-server'
mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=False,
)
logger.info("DNS rebinding protection: disabled for Docker container compatibility")
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") # logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
# Run the server with SSE transport for HTTP connections # Run the server with SSE transport for HTTP connections

View File

@@ -50,9 +50,7 @@
"entity2_name": "用户", "entity2_name": "用户",
"entity2": { "entity2": {
"description": "叙述者,讲述个人工作与生活经历的个体", "description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782", "name": "用户"
"name": "用户",
"id": "3d3896797b334572a80d57590026063d"
} }
}, },
{ {
@@ -62,9 +60,7 @@
"entity2_name": "身份信息", "entity2_name": "身份信息",
"entity2": { "entity2": {
"description": "用于个人身份识别的数据", "description": "用于个人身份识别的数据",
"statement_id": "030afd362e9b4110b139e68e5d3e7143", "name": "身份信息"
"name": "身份信息",
"id": "aa766a517e82490599a9b3af54cfd933"
} }
}, },
{ {
@@ -74,9 +70,7 @@
"entity2_name": "6222023847595898", "entity2_name": "6222023847595898",
"entity2": { "entity2": {
"description": "用户的银行卡号码", "description": "用户的银行卡号码",
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f", "name": "6222023847595898"
"name": "6222023847595898",
"id": "610ba361918f4e68a65ce6ad06e5c7a0"
} }
}, },
{ {
@@ -88,9 +82,7 @@
"entity_idx": 1, "entity_idx": 1,
"aliases": ["上海办"], "aliases": ["上海办"],
"description": "位于上海的工作办公场所", "description": "位于上海的工作办公场所",
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669", "name": "上海办公室"
"name": "上海办公室",
"id": "fb702ef695c14e14af3e56786bc8815b"
} }
}, },
{ {
@@ -101,9 +93,7 @@
"entity2": { "entity2": {
"aliases": ["京", "京城", "北平"], "aliases": ["京", "京城", "北平"],
"description": "中国的首都城市,用户主要工作和生活所在地", "description": "中国的首都城市,用户主要工作和生活所在地",
"statement_id": "62beac695b1346f4871740a45db88782", "name": "北京"
"name": "北京",
"id": "81b2d1a571bb46a08a2d7a1e87efb945"
} }
}, },
{ {
@@ -113,9 +103,7 @@
"entity2_name": "身份证号", "entity2_name": "身份证号",
"entity2": { "entity2": {
"description": "中华人民共和国公民的身份号码", "description": "中华人民共和国公民的身份号码",
"statement_id": "030afd362e9b4110b139e68e5d3e7143", "name": "身份证号"
"name": "身份证号",
"id": "3e5f920645b2404fadb0e9ff60d1306e"
} }
} }
] ]

View File

@@ -269,8 +269,6 @@ class ReflectionEngine:
# # 检查是否真的有冲突 # # 检查是否真的有冲突
conflicts_found='' conflicts_found=''
# 记录冲突数据
await self._log_data("conflict", conflict_data)
conflicts_found='' conflicts_found=''
# 3. 解决冲突 # 3. 解决冲突
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets) solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
@@ -288,8 +286,6 @@ class ReflectionEngine:
conflicts_resolved = len(solved_data) conflicts_resolved = len(solved_data)
logging.info(f"解决了 {conflicts_resolved} 个冲突") logging.info(f"解决了 {conflicts_resolved} 个冲突")
# 记录解决方案
await self._log_data("solved_data", solved_data)
# 4. 应用反思结果(更新记忆库) # 4. 应用反思结果(更新记忆库)
memories_updated=await self._apply_reflection_results(solved_data) memories_updated=await self._apply_reflection_results(solved_data)
@@ -390,14 +386,7 @@ class ReflectionEngine:
memory_verifies.append(item['memory_verify']) memory_verifies.append(item['memory_verify'])
result_data['memory_verifies'] = memory_verifies result_data['memory_verifies'] = memory_verifies
result_data['quality_assessments'] = quality_assessments result_data['quality_assessments'] = quality_assessments
conflicts_found=''
# 检查是否真的有冲突
has_conflict = conflict_data[0].get('conflict', False)
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
# 记录冲突数据
await self._log_data("conflict", conflict_data)
# Clean conflict_data and memory_verify and quality_assessment # Clean conflict_data and memory_verify and quality_assessment
cleaned_conflict_data = [] cleaned_conflict_data = []
@@ -407,6 +396,7 @@ class ReflectionEngine:
'conflict': item['conflict'] 'conflict': item['conflict']
} }
cleaned_conflict_data.append(cleaned_item) cleaned_conflict_data.append(cleaned_item)
# 3. 解决冲突 # 3. 解决冲突
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data) solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
if not solved_data: if not solved_data:
@@ -645,26 +635,7 @@ class ReflectionEngine:
success_count = await neo4j_data(changes) success_count = await neo4j_data(changes)
return success_count return success_count
async def _log_data(self, label: str, data: Any) -> None:
"""
记录数据到文件
Args:
label: 数据标签
data: 要记录的数据
"""
def _write():
try:
with open("reflexion_data.json", "a", encoding="utf-8") as f:
f.write(f"### {label} ###\n")
json.dump(data, f, ensure_ascii=False, indent=4)
f.write("\n\n")
except Exception as e:
logging.warning(f"记录数据失败: {e}")
# 在后台线程中执行写入,避免阻塞事件循环
await asyncio.to_thread(_write)
# 基于时间的反思方法 # 基于时间的反思方法
async def time_based_reflection( async def time_based_reflection(
@@ -753,4 +724,3 @@ class ReflectionEngine:
raise ValueError(f"未知的反思基线: {self.config.baseline}") raise ValueError(f"未知的反思基线: {self.config.baseline}")

View File

@@ -17,10 +17,12 @@
- **日期属性冲突**: 同一人的生日等单值属性出现多值 - **日期属性冲突**: 同一人的生日等单值属性出现多值
- **先后约束违反**: 存在A→B约束但t(A)>t(B)(如入学>毕业) - **先后约束违反**: 存在A→B约束但t(A)>t(B)(如入学>毕业)
- **互斥重叠**: 同一时间出现在不同地点等互斥事件 - **互斥重叠**: 同一时间出现在不同地点等互斥事件
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
### 事实冲突 ### 事实冲突
- **属性互斥**: 同一实体的相反属性(喜欢↔不喜欢) - **属性互斥**: 同一实体的相反属性(喜欢↔不喜欢)
- **关系矛盾**: 同一实体在相同语境下的不同关系描述 - **关系矛盾**: 同一实体在相同语境下的不同关系描述
- **身份冲突**: 同一实体被赋予不同类型或角色 - **身份冲突**: 同一实体被赋予不同类型或角色
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
### 混合冲突 ### 混合冲突
检测所有逻辑不一致或相互矛盾的记录。 检测所有逻辑不一致或相互矛盾的记录。
**检测原则**: **检测原则**:

View File

@@ -171,7 +171,6 @@
] ]
} }
``` ```
**输出要求**: **输出要求**:
- 只输出JSON不添加解释文本 - 只输出JSON不添加解释文本
- 使用标准双引号,必要时转义 - 使用标准双引号,必要时转义

View File

@@ -7,7 +7,7 @@ from typing import List, Dict, Any
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any], async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
baseline: str = "TIME", baseline: str = "TIME",
memory_verify: bool = False,quality_assessment:bool = False, memory_verify: bool = False,quality_assessment:bool = False,
statement_databasets: List[str] = [],language_type:str = "zh") -> str: statement_databasets: List[str] = [],language_type:str = "zh") -> str:
@@ -16,7 +16,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
Args: Args:
evaluate_data: The data to evaluate evaluate_data: The data to evaluate
schema: The JSON schema to use for the output. schema: The Pydantic model class or JSON schema to use for the output.
baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT) baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT)
memory_verify: Whether to enable memory verification for privacy detection memory_verify: Whether to enable memory verification for privacy detection
@@ -25,9 +25,17 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
""" """
template = prompt_env.get_template("evaluate.jinja2") template = prompt_env.get_template("evaluate.jinja2")
# Convert Pydantic model to JSON schema if needed
if hasattr(schema, 'model_json_schema'):
json_schema = schema.model_json_schema()
elif hasattr(schema, 'schema'):
json_schema = schema.schema()
else:
json_schema = schema
rendered_prompt = template.render( rendered_prompt = template.render(
evaluate_data=evaluate_data, evaluate_data=evaluate_data,
json_schema=schema, json_schema=json_schema,
baseline=baseline, baseline=baseline,
memory_verify=memory_verify, memory_verify=memory_verify,
quality_assessment=quality_assessment, quality_assessment=quality_assessment,
@@ -36,14 +44,15 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
) )
return rendered_prompt return rendered_prompt
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False,
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
statement_databasets: List[str] = [],language_type:str = "zh") -> str: statement_databasets: List[str] = [],language_type:str = "zh") -> str:
""" """
Renders the reflexion prompt using the reflexion_optimized.jinja2 template. Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
Args: Args:
data: The data to reflex on. data: The data to reflex on.
schema: The JSON schema to use for the output. schema: The Pydantic model class or JSON schema to use for the output.
baseline: The baseline type for conflict resolution. baseline: The baseline type for conflict resolution.
Returns: Returns:
@@ -51,7 +60,15 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any],
""" """
template = prompt_env.get_template("reflexion.jinja2") template = prompt_env.get_template("reflexion.jinja2")
rendered_prompt = template.render(data=data, json_schema=schema, # Convert Pydantic model to JSON schema if needed
if hasattr(schema, 'model_json_schema'):
json_schema = schema.model_json_schema()
elif hasattr(schema, 'schema'):
json_schema = schema.schema()
else:
json_schema = schema
rendered_prompt = template.render(data=data, json_schema=json_schema,
baseline=baseline,memory_verify=memory_verify, baseline=baseline,memory_verify=memory_verify,
statement_databasets=statement_databasets,language_type=language_type) statement_databasets=statement_databasets,language_type=language_type)

View File

@@ -13,7 +13,7 @@ from io import BytesIO
from os import PathLike from os import PathLike
from pathlib import Path from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, Union
import numpy as np import numpy as np
import pdfplumber import pdfplumber
@@ -439,7 +439,7 @@ class MinerUParser(RAGPdfParser):
def parse_pdf( def parse_pdf(
self, self,
filepath: str | PathLike[str], filepath: str | PathLike[str],
binary: BytesIO | bytes, binary: Optional[Union[BytesIO, bytes]] = None,
callback: Optional[Callable] = None, callback: Optional[Callable] = None,
*, *,
output_dir: Optional[str] = None, output_dir: Optional[str] = None,

View File

@@ -740,8 +740,9 @@ class ElasticSearchVector(BaseVector):
self._client.indices.create(index=self._collection_name, body=index_mapping) self._client.indices.create(index=self._collection_name, body=index_mapping)
class ElasticSearchVectorFactory(ABC): class ElasticSearchVectorFactory:
def init_vector(self, knowledge: Knowledge) -> ElasticSearchVector: @staticmethod
def init_vector(knowledge: Knowledge) -> ElasticSearchVector:
collection_name = f"Vector_index_{knowledge.id}_Node" collection_name = f"Vector_index_{knowledge.id}_Node"
# Use regular Elasticsearch with config values # Use regular Elasticsearch with config values
@@ -763,17 +764,17 @@ class ElasticSearchVectorFactory(ABC):
} }
) )
if knowledge.embedding and knowledge.reranker: if knowledge.embedding is None:
return ElasticSearchVector( raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
index_name=collection_name, if knowledge.reranker is None:
config=ElasticSearchConfig(**config_dict), raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
embedding_config=knowledge.embedding.api_keys[0],
reranker_config=knowledge.reranker.api_keys[0] return ElasticSearchVector(
) index_name=collection_name,
else: config=ElasticSearchConfig(**config_dict),
if knowledge.embedding is None: embedding_config=knowledge.embedding.api_keys[0],
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}") reranker_config=knowledge.reranker.api_keys[0]
if knowledge.reranker is None: )
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")

View File

@@ -9,7 +9,7 @@ from app.core.workflow.nodes.assigner import AssignerNode
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.if_else import IfElseNode
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.start import StartNode
@@ -26,6 +26,6 @@ __all__ = [
"EndNode", "EndNode",
"NodeFactory", "NodeFactory",
"WorkflowNode", "WorkflowNode",
# "KnowledgeRetrievalNode", "KnowledgeRetrievalNode",
"AssignerNode", "AssignerNode",
] ]

View File

@@ -1,21 +1,32 @@
from pydantic import Field from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import AssignmentOperator from app.core.workflow.nodes.enums import AssignmentOperator
class AssignerNodeConfig(BaseNodeConfig): class AssignmentItem(BaseModel):
"""
Single assignment definition.
"""
variable_selector: str | list[str] = Field( variable_selector: str | list[str] = Field(
..., ...,
description="Variables to be assigned", description="Target variable name(s) to assign",
) )
operation: AssignmentOperator = Field( operation: AssignmentOperator = Field(
..., ...,
description="Operator to assign", description="Assignment operator",
) )
value: str | list[str] = Field( value: str | list[str] = Field(
..., ...,
description="Values to assign", description="Value(s) to assign to the variable(s)",
)
class AssignerNodeConfig(BaseNodeConfig):
assignments: list[AssignmentItem] = Field(
...,
description="List of variable assignment definitions",
) )

View File

@@ -29,52 +29,52 @@ class AssignerNode(BaseNode):
""" """
# Initialize a variable pool for accessing conversation, node, and system variables # Initialize a variable pool for accessing conversation, node, and system variables
pool = VariablePool(state) pool = VariablePool(state)
for assignment in self.typed_config.assignments:
# Get the target variable selector (e.g., "conv.test")
variable_selector = assignment.variable_selector
if isinstance(variable_selector, str):
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
variable_selector = variable_selector.split('.')
# Get the target variable selector (e.g., "conv.test") # Only conversation variables ('conv') are allowed
variable_selector = self.typed_config.variable_selector if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
if isinstance(variable_selector, str): raise ValueError("Only conversation variables can be assigned.")
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
variable_selector = variable_selector.split('.')
# Only conversation variables ('conv') are allowed # Get the value or expression to assign
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature) value = assignment.value
raise ValueError("Only conversation variables can be assigned.") if isinstance(value, list):
value = '.'.join(value)
value = ExpressionEvaluator.evaluate(
expression=value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
# Get the value or expression to assign # Select the appropriate assignment operator instance based on the target variable type
value = self.typed_config.value operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
if isinstance(value, list): pool, variable_selector, value
value = '.'.join(value) )
value = ExpressionEvaluator.evaluate(
expression=value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
# Select the appropriate assignment operator instance based on the target variable type # Execute the configured assignment operation
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))( match assignment.operation:
pool, variable_selector, value case AssignmentOperator.ASSIGN:
) operator.assign()
case AssignmentOperator.CLEAR:
# Execute the configured assignment operation operator.clear()
match self.typed_config.operation: case AssignmentOperator.ADD:
case AssignmentOperator.ASSIGN: operator.add()
operator.assign() case AssignmentOperator.SUBTRACT:
case AssignmentOperator.CLEAR: operator.subtract()
operator.clear() case AssignmentOperator.MULTIPLY:
case AssignmentOperator.ADD: operator.multiply()
operator.add() case AssignmentOperator.DIVIDE:
case AssignmentOperator.SUBTRACT: operator.divide()
operator.subtract() case AssignmentOperator.APPEND:
case AssignmentOperator.MULTIPLY: operator.append()
operator.multiply() case AssignmentOperator.REMOVE_FIRST:
case AssignmentOperator.DIVIDE: operator.remove_first()
operator.divide() case AssignmentOperator.REMOVE_LAST:
case AssignmentOperator.APPEND: operator.remove_last()
operator.append() case _:
case AssignmentOperator.REMOVE_FIRST: raise ValueError(f"Invalid Operator: {assignment.operation}")
operator.remove_first()
case AssignmentOperator.REMOVE_LAST:
operator.remove_last()
case _:
raise ValueError(f"Invalid Operator: {self.typed_config.operation}")

View File

@@ -14,7 +14,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
__all__ = [ __all__ = [
@@ -30,6 +30,6 @@ __all__ = [
"AgentNodeConfig", "AgentNodeConfig",
"TransformNodeConfig", "TransformNodeConfig",
"IfElseNodeConfig", "IfElseNodeConfig",
# "KnowledgeRetrievalNodeConfig", "KnowledgeRetrievalNodeConfig",
"AssignerNodeConfig", "AssignerNodeConfig",
] ]

View File

@@ -9,6 +9,7 @@ import re
import asyncio import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -127,24 +128,26 @@ class EndNode(BaseNode):
return parts return parts
async def execute_stream(self, state: WorkflowState): async def execute_stream(self, state: WorkflowState):
"""流式执行 end 节点业务逻辑 """Execute End node business logic (streaming)
智能输出策略: Smart output strategy:
1. 检测模板中是否引用了直接上游节点 1. Check if template references a direct upstream LLM node
2. 如果引用了,只输出该引用**之后**的部分(后缀) 2. If yes, only output the part AFTER that reference (suffix)
3. 前缀和引用内容已经在上游节点流式输出时发送了 3. Prefix and LLM content have already been sent during LLM node streaming
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' Note: Only LLM nodes get this special treatment. Other node types output normally.
- 直接上游节点是 llm_qa
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- LLM 内容在 LLM 节点流式输出 - Direct upstream LLM node is llm_qa
- End 节点只输出 ' lalalalala a'(后缀,一次性输出) - Prefix '{{start.test}}hahaha ' was sent before LLM node streaming
- LLM content was streamed during LLM node execution
- End node only outputs ' lalalalala a' (suffix, sent as one chunk)
Args: Args:
state: 工作流状态 state: Workflow state
Yields: Yields:
完成标记 Completion marker
""" """
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
@@ -156,39 +159,45 @@ class EndNode(BaseNode):
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
# 找到直接上游节点 # Find direct upstream LLM nodes
direct_upstream_nodes = [] direct_upstream_llm_nodes = []
for edge in self.workflow_config.get("edges", []): for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id: if edge.get("target") == self.node_id:
source_node_id = edge.get("source") source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id) # Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []):
print("="*50)
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id)
break
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")
# 解析模板部分 # Parse template parts
parts = self._parse_template_parts(output_template, state) parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
for i, part in enumerate(parts): for i, part in enumerate(parts):
logger.info(f"[模板解析] part[{i}]: {part}") logger.info(f"[模板解析] part[{i}]: {part}")
# 找到第一个引用直接上游节点的动态引用 # Find the first reference to a direct upstream LLM node
upstream_ref_index = None upstream_llm_ref_index = None
for i, part in enumerate(parts): for i, part in enumerate(parts):
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes: if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
upstream_ref_index = i upstream_llm_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}")
break break
if upstream_ref_index is None: if upstream_llm_ref_index is None:
# 没有引用直接上游节点,输出完整模板内容 # No reference to direct upstream LLM node, output complete template content
output = self._render_template(output_template, state) output = self._render_template(output_template, state)
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'") logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
# 通过 writer 发送完整内容(作为一个 message chunk # Send complete content via writer (as a single message chunk)
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
writer = get_stream_writer() writer = get_stream_writer()
writer({ writer({
"type": "message", # End 节点的输出使用 message 类型 "type": "message", # End node output uses message type
"node_id": self.node_id, "node_id": self.node_id,
"chunk": output, "chunk": output,
"full_content": output, "full_content": output,
@@ -197,17 +206,17 @@ class EndNode(BaseNode):
}) })
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
# yield 完成标记 # yield completion marker
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
# 有引用直接上游节点,只输出该引用之后的部分(后缀) # Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# 收集后缀部分 # Collect suffix parts
suffix_parts = [] suffix_parts = []
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1}{len(parts) - 1}") logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1}{len(parts) - 1}")
for i in range(upstream_ref_index + 1, len(parts)): for i in range(upstream_llm_ref_index + 1, len(parts)):
part = parts[i] part = parts[i]
logger.info(f"[后缀调试] 处理 part[{i}]: {part}") logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
if part["type"] == "static": if part["type"] == "static":
@@ -261,7 +270,7 @@ class EndNode(BaseNode):
}) })
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}") logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else: else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空不发送upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}") logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空不发送upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息 # 统计信息
node_outputs = state.get("node_outputs", {}) node_outputs = state.get("node_outputs", {})

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.knowledge.node import KnowledgeRetrievalNode
__all__ = ["KnowledgeRetrievalNode", "KnowledgeRetrievalNodeConfig"]

View File

@@ -0,0 +1,38 @@
from uuid import UUID
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.schemas.chunk_schema import RetrieveType
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
    """Typed configuration for the knowledge-retrieval workflow node.

    ``query`` and ``kb_ids`` are required; the remaining fields fall back to
    the defaults declared below when omitted.
    """

    # --- Required settings -------------------------------------------------
    query: str = Field(..., description="Search query string")
    kb_ids: list[UUID] = Field(..., description="Knowledge base IDs")

    # --- Optional tuning knobs ---------------------------------------------
    similarity_threshold: float = Field(default=0.2, description="Knowledge base similarity threshold")
    vector_similarity_weight: float = Field(default=0.3, description="Knowledge base vector similarity weight")
    top_k: int = Field(default=4, description="Knowledge base top k")

    # Retrieval strategy; defaults to RetrieveType.PARTICIPLE.
    retrieve_type: RetrieveType = Field(default=RetrieveType.PARTICIPLE, description="Retrieve type")

View File

@@ -0,0 +1,93 @@
import logging
from typing import Any
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.db import get_db_context
from app.models import knowledge_model, knowledgeshare_model
from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
from app.services import knowledge_service, knowledgeshare_service
logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
query = self._render_template(self.typed_config.query, state)
with get_db_context() as db:
filters = [
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
existing_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
filters = [
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
]
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
)
existing_ids.extend(items)
if not existing_ids:
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
kb_id = existing_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
indices = ",".join(uuid_strs)
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
match self.typed_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.similarity_threshold)
return [chunk.model_dump() for chunk in rs]
case RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
return [chunk.model_dump() for chunk in rs]
case _:
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.similarity_threshold)
# Efficient deduplication
seen_ids = set()
unique_rs = []
for doc in rs1 + rs2:
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
return [chunk.model_dump() for chunk in rs]

View File

@@ -7,7 +7,7 @@
import logging import logging
from typing import Any, Union from typing import Any, Union
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.end import EndNode
@@ -29,7 +29,7 @@ WorkflowNode = Union[
AgentNode, AgentNode,
TransformNode, TransformNode,
AssignerNode, AssignerNode,
# KnowledgeRetrievalNode, KnowledgeRetrievalNode,
] ]
@@ -47,7 +47,7 @@ class NodeFactory:
NodeType.AGENT: AgentNode, NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode, NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode, NodeType.IF_ELSE: IfElseNode,
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode, NodeType.ASSIGNER: AssignerNode,
} }

View File

@@ -116,7 +116,6 @@ class StartNode(BaseNode):
return processed return processed
def _extract_input(self, state: WorkflowState) -> dict[str, Any]: def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录) """提取输入数据(用于记录)

View File

@@ -52,7 +52,7 @@ def get_knowledges_paginated(
raise raise
def get_chunded_knowledgeids( def get_chunked_knowledgeids(
db: Session, db: Session,
filters: list filters: list
) -> list: ) -> list:

View File

@@ -31,7 +31,7 @@ class BaseDataSchema(BaseModel):
# 保持原有必需字段为可选,以兼容不同数据源 # 保持原有必需字段为可选,以兼容不同数据源
id: Optional[str] = Field(None, description="The unique identifier for the data entry.") id: Optional[str] = Field(None, description="The unique identifier for the data entry.")
statement: Optional[str] = Field(None, description="The statement text.") statement: Optional[str] = Field(None, description="The statement text.")
created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.") created_at: Optional[str] = Field(None, description="The creation timestamp in ISO 8601 format.")
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.") expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
description: Optional[str] = Field(None, description="The description of the data entry.") description: Optional[str] = Field(None, description="The description of the data entry.")
@@ -46,6 +46,14 @@ class BaseDataSchema(BaseModel):
relationship: Optional[Union[str, Dict[str, Any]]] = Field(None, description="The relationship object or string.") relationship: Optional[Union[str, Dict[str, Any]]] = Field(None, description="The relationship object or string.")
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.") entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
@model_validator(mode="before")
@classmethod
def _set_default_created_at(cls, v):
    """Fill in ``created_at`` with the current time when it is missing.

    Runs before field validation. Only ``dict`` input is inspected; any
    other kind of input is passed through untouched.

    Args:
        v: The raw input handed to the model before validation.

    Returns:
        ``v`` itself when it is not a dict or already has a non-None
        ``created_at``; otherwise a shallow copy of ``v`` whose
        ``created_at`` is ``datetime.now().isoformat()``.
    """
    if isinstance(v, dict) and v.get("created_at") is None:
        # Imported locally so the block does not depend on a module-level import.
        from datetime import datetime

        # Return a new dict instead of assigning into ``v``: validation
        # should not modify the payload object owned by the caller.
        return {**v, "created_at": datetime.now().isoformat()}
    return v
class QualityAssessmentSchema(BaseModel): class QualityAssessmentSchema(BaseModel):
"""Schema for memory quality assessment results.""" """Schema for memory quality assessment results."""

View File

@@ -45,7 +45,7 @@ def get_chunded_knowledgeids(
business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}") business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}")
try: try:
items = knowledge_repository.get_chunded_knowledgeids( items = knowledge_repository.get_chunked_knowledgeids(
db=db, db=db,
filters=filters filters=filters
) )

View File

@@ -44,11 +44,11 @@ nodes:
- role: user - role: user
content: "{{ sys.message }}" content: "{{ sys.message }}"
model_id: gpt-3.5-turbo model_id: null
temperature: 0.7 temperature: 0.7
max_tokens: 1000 max_tokens: 1000
position: position:
x: 300 x: 500
y: 100 y: 100
- id: end - id: end
@@ -57,7 +57,7 @@ nodes:
config: config:
output: "{{ llm_qa.output }}" output: "{{ llm_qa.output }}"
position: position:
x: 500 x: 900
y: 100 y: 100
edges: edges:

View File

@@ -3,7 +3,7 @@ version: '3.9'
services: services:
# MCP Server - standalone service # MCP Server - standalone service
mcp-server: mcp-server:
image: redbear-mem:latest image: redbear-mem-open:latest
container_name: mcp-server container_name: mcp-server
ports: ports:
- "8081:8081" # MCP server port - "8081:8081" # MCP server port
@@ -28,14 +28,14 @@ services:
# FastAPI application - connects to MCP server # FastAPI application - connects to MCP server
api: api:
image: redbear-mem:latest image: redbear-mem-open:latest
container_name: api container_name: api
ports: ports:
- "8002:8000" - "8002:8000"
env_file: env_file:
- .env - .env
environment: environment:
- MCP_SERVER_URL=http://mcp-server:8081 - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name
- SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces - SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces
volumes: volumes:
- ./files:/files - ./files:/files
@@ -51,12 +51,12 @@ services:
# Celery worker - connects to MCP server # Celery worker - connects to MCP server
worker: worker:
image: redbear-mem:latest image: redbear-mem-open:latest
container_name: worker container_name: worker
env_file: env_file:
- .env - .env
environment: environment:
- MCP_SERVER_URL=http://mcp-server:8081 - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name
volumes: volumes:
- ./files:/files - ./files:/files
- /etc/localtime:/etc/localtime:ro - /etc/localtime:/etc/localtime:ro

View File

@@ -71,6 +71,18 @@ ENABLE_SINGLE_SESSION=
MAX_FILE_SIZE=52428800 # 50MB: 50 * 1024 * 1024 MAX_FILE_SIZE=52428800 # 50MB: 50 * 1024 * 1024
FILE_PATH=/files FILE_PATH=/files
# RAG Setting
DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
HF_ENDPOINT=https://hf-mirror.com
MINERU_EXECUTABLE=mineru
MINERU_APISERVER=http://host.docker.internal:9987
MINERU_OUTPUT_DIR=/files
MINERU_BACKEND=pipeline
MINERU_DELETE_OUTPUT=1
TEXTLN_APISERVER=https://api.textin.com/ai/service/v1/pdf_to_markdown
TEXTLN_APP_ID=
TEXTLN_SECRET_CODE=
# VOLC ASR # VOLC ASR
VOLC_APP_KEY= VOLC_APP_KEY=
VOLC_ACCESS_KEY= VOLC_ACCESS_KEY=