Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management

This commit is contained in:
Ke Sun
2025-12-24 16:07:27 +08:00
24 changed files with 365 additions and 202 deletions

View File

@@ -131,6 +131,18 @@ def main():
# Get MCP port from environment (default: 8081)
mcp_port = int(os.getenv("MCP_PORT", "8081"))
logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
# Configure DNS rebinding protection for Docker container compatibility
from mcp.server.fastmcp.server import TransportSecuritySettings
# Disable DNS rebinding protection to allow Docker container hostnames
# This allows containers to connect using service names like 'mcp-server'
mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=False,
)
logger.info("DNS rebinding protection: disabled for Docker container compatibility")
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
# Run the server with SSE transport for HTTP connections

View File

@@ -50,9 +50,7 @@
"entity2_name": "用户",
"entity2": {
"description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782",
"name": "用户",
"id": "3d3896797b334572a80d57590026063d"
"name": "用户"
}
},
{
@@ -62,9 +60,7 @@
"entity2_name": "身份信息",
"entity2": {
"description": "用于个人身份识别的数据",
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
"name": "身份信息",
"id": "aa766a517e82490599a9b3af54cfd933"
"name": "身份信息"
}
},
{
@@ -74,9 +70,7 @@
"entity2_name": "6222023847595898",
"entity2": {
"description": "用户的银行卡号码",
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
"name": "6222023847595898",
"id": "610ba361918f4e68a65ce6ad06e5c7a0"
"name": "6222023847595898"
}
},
{
@@ -88,9 +82,7 @@
"entity_idx": 1,
"aliases": ["上海办"],
"description": "位于上海的工作办公场所",
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
"name": "上海办公室",
"id": "fb702ef695c14e14af3e56786bc8815b"
"name": "上海办公室"
}
},
{
@@ -101,9 +93,7 @@
"entity2": {
"aliases": ["京", "京城", "北平"],
"description": "中国的首都城市,用户主要工作和生活所在地",
"statement_id": "62beac695b1346f4871740a45db88782",
"name": "北京",
"id": "81b2d1a571bb46a08a2d7a1e87efb945"
"name": "北京"
}
},
{
@@ -113,9 +103,7 @@
"entity2_name": "身份证号",
"entity2": {
"description": "中华人民共和国公民的身份号码",
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
"name": "身份证号",
"id": "3e5f920645b2404fadb0e9ff60d1306e"
"name": "身份证号"
}
}
]

View File

@@ -269,8 +269,6 @@ class ReflectionEngine:
# # 检查是否真的有冲突
conflicts_found=''
# 记录冲突数据
await self._log_data("conflict", conflict_data)
conflicts_found=''
# 3. 解决冲突
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
@@ -288,8 +286,6 @@ class ReflectionEngine:
conflicts_resolved = len(solved_data)
logging.info(f"解决了 {conflicts_resolved} 个冲突")
# 记录解决方案
await self._log_data("solved_data", solved_data)
# 4. 应用反思结果(更新记忆库)
memories_updated=await self._apply_reflection_results(solved_data)
@@ -390,14 +386,7 @@ class ReflectionEngine:
memory_verifies.append(item['memory_verify'])
result_data['memory_verifies'] = memory_verifies
result_data['quality_assessments'] = quality_assessments
# 检查是否真的有冲突
has_conflict = conflict_data[0].get('conflict', False)
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
# 记录冲突数据
await self._log_data("conflict", conflict_data)
conflicts_found=''
# Clean conflict_data: strip the memory_verify and quality_assessment fields
cleaned_conflict_data = []
@@ -407,6 +396,7 @@ class ReflectionEngine:
'conflict': item['conflict']
}
cleaned_conflict_data.append(cleaned_item)
# 3. 解决冲突
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
if not solved_data:
@@ -645,26 +635,7 @@ class ReflectionEngine:
success_count = await neo4j_data(changes)
return success_count
async def _log_data(self, label: str, data: Any) -> None:
"""
记录数据到文件
Args:
label: 数据标签
data: 要记录的数据
"""
def _write():
try:
with open("reflexion_data.json", "a", encoding="utf-8") as f:
f.write(f"### {label} ###\n")
json.dump(data, f, ensure_ascii=False, indent=4)
f.write("\n\n")
except Exception as e:
logging.warning(f"记录数据失败: {e}")
# 在后台线程中执行写入,避免阻塞事件循环
await asyncio.to_thread(_write)
# 基于时间的反思方法
async def time_based_reflection(
@@ -753,4 +724,3 @@ class ReflectionEngine:
raise ValueError(f"未知的反思基线: {self.config.baseline}")

View File

@@ -17,10 +17,12 @@
- **日期属性冲突**: 同一人的生日等单值属性出现多值
- **先后约束违反**: 存在A→B约束但t(A)>t(B)(如入学>毕业)
- **互斥重叠**: 同一时间出现在不同地点等互斥事件
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
### 事实冲突
- **属性互斥**: 同一实体的相反属性(喜欢↔不喜欢)
- **关系矛盾**: 同一实体在相同语境下的不同关系描述
- **身份冲突**: 同一实体被赋予不同类型或角色
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
### 混合冲突
检测所有逻辑不一致或相互矛盾的记录。
**检测原则**:

View File

@@ -171,7 +171,6 @@
]
}
```
**输出要求**:
- 只输出JSON不添加解释文本
- 使用标准双引号,必要时转义

View File

@@ -7,7 +7,7 @@ from typing import List, Dict, Any
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any],
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
baseline: str = "TIME",
memory_verify: bool = False,quality_assessment:bool = False,
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
@@ -16,7 +16,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
Args:
evaluate_data: The data to evaluate
schema: The JSON schema to use for the output.
schema: The Pydantic model class or JSON schema to use for the output.
baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT)
memory_verify: Whether to enable memory verification for privacy detection
@@ -25,9 +25,17 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
"""
template = prompt_env.get_template("evaluate.jinja2")
# Convert Pydantic model to JSON schema if needed
if hasattr(schema, 'model_json_schema'):
json_schema = schema.model_json_schema()
elif hasattr(schema, 'schema'):
json_schema = schema.schema()
else:
json_schema = schema
rendered_prompt = template.render(
evaluate_data=evaluate_data,
json_schema=schema,
json_schema=json_schema,
baseline=baseline,
memory_verify=memory_verify,
quality_assessment=quality_assessment,
@@ -36,14 +44,15 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
)
return rendered_prompt
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False,
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
"""
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
Args:
data: The data to reflex on.
schema: The JSON schema to use for the output.
schema: The Pydantic model class or JSON schema to use for the output.
baseline: The baseline type for conflict resolution.
Returns:
@@ -51,7 +60,15 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any],
"""
template = prompt_env.get_template("reflexion.jinja2")
rendered_prompt = template.render(data=data, json_schema=schema,
# Convert Pydantic model to JSON schema if needed
if hasattr(schema, 'model_json_schema'):
json_schema = schema.model_json_schema()
elif hasattr(schema, 'schema'):
json_schema = schema.schema()
else:
json_schema = schema
rendered_prompt = template.render(data=data, json_schema=json_schema,
baseline=baseline,memory_verify=memory_verify,
statement_databasets=statement_databasets,language_type=language_type)

View File

@@ -13,7 +13,7 @@ from io import BytesIO
from os import PathLike
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import numpy as np
import pdfplumber
@@ -439,7 +439,7 @@ class MinerUParser(RAGPdfParser):
def parse_pdf(
self,
filepath: str | PathLike[str],
binary: BytesIO | bytes,
binary: Optional[Union[BytesIO, bytes]] = None,
callback: Optional[Callable] = None,
*,
output_dir: Optional[str] = None,

View File

@@ -740,8 +740,9 @@ class ElasticSearchVector(BaseVector):
self._client.indices.create(index=self._collection_name, body=index_mapping)
class ElasticSearchVectorFactory(ABC):
def init_vector(self, knowledge: Knowledge) -> ElasticSearchVector:
class ElasticSearchVectorFactory:
@staticmethod
def init_vector(knowledge: Knowledge) -> ElasticSearchVector:
collection_name = f"Vector_index_{knowledge.id}_Node"
# Use regular Elasticsearch with config values
@@ -763,17 +764,17 @@ class ElasticSearchVectorFactory(ABC):
}
)
if knowledge.embedding and knowledge.reranker:
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(**config_dict),
embedding_config=knowledge.embedding.api_keys[0],
reranker_config=knowledge.reranker.api_keys[0]
)
else:
if knowledge.embedding is None:
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
if knowledge.reranker is None:
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
if knowledge.embedding is None:
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
if knowledge.reranker is None:
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(**config_dict),
embedding_config=knowledge.embedding.api_keys[0],
reranker_config=knowledge.reranker.api_keys[0]
)

View File

@@ -9,7 +9,7 @@ from app.core.workflow.nodes.assigner import AssignerNode
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.if_else import IfElseNode
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
@@ -26,6 +26,6 @@ __all__ = [
"EndNode",
"NodeFactory",
"WorkflowNode",
# "KnowledgeRetrievalNode",
"KnowledgeRetrievalNode",
"AssignerNode",
]

View File

@@ -1,21 +1,32 @@
from pydantic import Field
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import AssignmentOperator
class AssignerNodeConfig(BaseNodeConfig):
class AssignmentItem(BaseModel):
"""
Single assignment definition.
"""
variable_selector: str | list[str] = Field(
...,
description="Variables to be assigned",
description="Target variable name(s) to assign",
)
operation: AssignmentOperator = Field(
...,
description="Operator to assign",
description="Assignment operator",
)
value: str | list[str] = Field(
...,
description="Values to assign",
description="Value(s) to assign to the variable(s)",
)
class AssignerNodeConfig(BaseNodeConfig):
assignments: list[AssignmentItem] = Field(
...,
description="List of variable assignment definitions",
)

View File

@@ -29,52 +29,52 @@ class AssignerNode(BaseNode):
"""
# Initialize a variable pool for accessing conversation, node, and system variables
pool = VariablePool(state)
for assignment in self.typed_config.assignments:
# Get the target variable selector (e.g., "conv.test")
variable_selector = assignment.variable_selector
if isinstance(variable_selector, str):
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
variable_selector = variable_selector.split('.')
# Get the target variable selector (e.g., "conv.test")
variable_selector = self.typed_config.variable_selector
if isinstance(variable_selector, str):
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
variable_selector = variable_selector.split('.')
# Only conversation variables ('conv') are allowed
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
raise ValueError("Only conversation variables can be assigned.")
# Only conversation variables ('conv') are allowed
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
raise ValueError("Only conversation variables can be assigned.")
# Get the value or expression to assign
value = assignment.value
if isinstance(value, list):
value = '.'.join(value)
value = ExpressionEvaluator.evaluate(
expression=value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
# Get the value or expression to assign
value = self.typed_config.value
if isinstance(value, list):
value = '.'.join(value)
value = ExpressionEvaluator.evaluate(
expression=value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
# Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
pool, variable_selector, value
)
# Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
pool, variable_selector, value
)
# Execute the configured assignment operation
match self.typed_config.operation:
case AssignmentOperator.ASSIGN:
operator.assign()
case AssignmentOperator.CLEAR:
operator.clear()
case AssignmentOperator.ADD:
operator.add()
case AssignmentOperator.SUBTRACT:
operator.subtract()
case AssignmentOperator.MULTIPLY:
operator.multiply()
case AssignmentOperator.DIVIDE:
operator.divide()
case AssignmentOperator.APPEND:
operator.append()
case AssignmentOperator.REMOVE_FIRST:
operator.remove_first()
case AssignmentOperator.REMOVE_LAST:
operator.remove_last()
case _:
raise ValueError(f"Invalid Operator: {self.typed_config.operation}")
# Execute the configured assignment operation
match assignment.operation:
case AssignmentOperator.ASSIGN:
operator.assign()
case AssignmentOperator.CLEAR:
operator.clear()
case AssignmentOperator.ADD:
operator.add()
case AssignmentOperator.SUBTRACT:
operator.subtract()
case AssignmentOperator.MULTIPLY:
operator.multiply()
case AssignmentOperator.DIVIDE:
operator.divide()
case AssignmentOperator.APPEND:
operator.append()
case AssignmentOperator.REMOVE_FIRST:
operator.remove_first()
case AssignmentOperator.REMOVE_LAST:
operator.remove_last()
case _:
raise ValueError(f"Invalid Operator: {assignment.operation}")

View File

@@ -14,7 +14,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
__all__ = [
@@ -30,6 +30,6 @@ __all__ = [
"AgentNodeConfig",
"TransformNodeConfig",
"IfElseNodeConfig",
# "KnowledgeRetrievalNodeConfig",
"KnowledgeRetrievalNodeConfig",
"AssignerNodeConfig",
]

View File

@@ -9,28 +9,29 @@ import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
class EndNode(BaseNode):
"""End 节点
工作流的结束节点,根据配置的模板输出最终结果。
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
"""
async def execute(self, state: WorkflowState) -> str:
"""执行 end 节点业务逻辑
Args:
state: 工作流状态
Returns:
最终输出字符串
"""
logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板
output_template = self.config.get("output")
@@ -39,11 +40,11 @@ class EndNode(BaseNode):
output = self._render_template(output_template, state)
else:
output = "工作流已完成"
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output
@@ -127,24 +128,26 @@ class EndNode(BaseNode):
return parts
async def execute_stream(self, state: WorkflowState):
"""流式执行 end 节点业务逻辑
"""Execute End node business logic (streaming)
智能输出策略:
1. 检测模板中是否引用了直接上游节点
2. 如果引用了,只输出该引用**之后**的部分(后缀)
3. 前缀和引用内容已经在上游节点流式输出时发送了
Smart output strategy:
1. Check if template references a direct upstream LLM node
2. If yes, only output the part AFTER that reference (suffix)
3. Prefix and LLM content have already been sent during LLM node streaming
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- 直接上游节点是 llm_qa
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
- LLM 内容在 LLM 节点流式输出
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
Note: Only LLM nodes get this special treatment. Other node types output normally.
Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- Direct upstream LLM node is llm_qa
- Prefix '{{start.test}}hahaha ' was sent before LLM node streaming
- LLM content was streamed during LLM node execution
- End node only outputs ' lalalalala a' (suffix, sent as one chunk)
Args:
state: 工作流状态
state: Workflow state
Yields:
完成标记
Completion marker
"""
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
@@ -156,39 +159,45 @@ class EndNode(BaseNode):
yield {"__final__": True, "result": output}
return
# 找到直接上游节点
direct_upstream_nodes = []
# Find direct upstream LLM nodes
direct_upstream_llm_nodes = []
for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
# Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []):
print("="*50)
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id)
break
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")
# 解析模板部分
# Parse template parts
parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
for i, part in enumerate(parts):
logger.info(f"[模板解析] part[{i}]: {part}")
# 找到第一个引用直接上游节点的动态引用
upstream_ref_index = None
# Find the first reference to a direct upstream LLM node
upstream_llm_ref_index = None
for i, part in enumerate(parts):
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
upstream_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
upstream_llm_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}")
break
if upstream_ref_index is None:
# 没有引用直接上游节点,输出完整模板内容
if upstream_llm_ref_index is None:
# No reference to direct upstream LLM node, output complete template content
output = self._render_template(output_template, state)
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'")
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
# 通过 writer 发送完整内容(作为一个 message chunk
# Send complete content via writer (as a single message chunk)
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End 节点的输出使用 message 类型
"type": "message", # End node output uses message type
"node_id": self.node_id,
"chunk": output,
"full_content": output,
@@ -197,17 +206,17 @@ class EndNode(BaseNode):
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
# yield 完成标记
# yield completion marker
yield {"__final__": True, "result": output}
return
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# 收集后缀部分
# Collect suffix parts
suffix_parts = []
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1}{len(parts) - 1}")
for i in range(upstream_ref_index + 1, len(parts)):
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1}{len(parts) - 1}")
for i in range(upstream_llm_ref_index + 1, len(parts)):
part = parts[i]
logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
if part["type"] == "static":
@@ -219,7 +228,7 @@ class EndNode(BaseNode):
# Other dynamic references (if there are multiple references)
node_id = part["node_id"]
field = part["field"]
# Use VariablePool to get variable value
pool = self.get_variable_pool(state)
try:
@@ -232,7 +241,7 @@ class EndNode(BaseNode):
# Convert to string if not None
suffix_parts.append(str(content) if content is not None else "")
# 拼接后缀
suffix = "".join(suffix_parts)
@@ -261,8 +270,8 @@ class EndNode(BaseNode):
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空不发送upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}")
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空不发送upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.knowledge.node import KnowledgeRetrievalNode
__all__ = ["KnowledgeRetrievalNode", "KnowledgeRetrievalNodeConfig"]

View File

@@ -0,0 +1,38 @@
from uuid import UUID
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.schemas.chunk_schema import RetrieveType
# Typed configuration for a workflow knowledge-retrieval node.
# (Documented with comments rather than a class docstring so the model's
# generated JSON schema description stays exactly as before.)
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
    # --- Required inputs (no default) ---
    query: str = Field(default=..., description="Search query string")
    kb_ids: list[UUID] = Field(default=..., description="Knowledge base IDs")

    # --- Retrieval tuning, all optional ---
    similarity_threshold: float = Field(default=0.2, description="Knowledge base similarity threshold")
    vector_similarity_weight: float = Field(default=0.3, description="Knowledge base vector similarity weight")
    top_k: int = Field(default=4, description="Knowledge base top k")

    # Which search strategy the node runs; defaults to RetrieveType.PARTICIPLE.
    retrieve_type: RetrieveType = Field(default=RetrieveType.PARTICIPLE, description="Retrieve type")

View File

@@ -0,0 +1,93 @@
import logging
from typing import Any
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.db import get_db_context
from app.models import knowledge_model, knowledgeshare_model
from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
from app.services import knowledge_service, knowledgeshare_service
logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode):
    """Workflow node that searches the configured knowledge bases for a query.

    The node resolves which of the configured knowledge base IDs are usable,
    builds one vector service from the first usable knowledge base, and runs
    the configured retrieval strategy across the indices of all usable IDs.
    """

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        """Initialise the base node, then parse ``self.config`` into a typed config."""
        super().__init__(node_config, workflow_config)
        self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)

    async def execute(self, state: WorkflowState) -> Any:
        """Run the knowledge search and return the matching chunks.

        Args:
            state: Current workflow state, used to render the query template.

        Returns:
            A list with the ``model_dump()`` of every chunk returned by the
            selected search strategy.

        Raises:
            RuntimeError: If none of the configured knowledge bases is usable,
                or the first usable knowledge base cannot be loaded.
        """
        cfg = self.typed_config
        query = self._render_template(cfg.query, state)
        kb_model = knowledge_model.Knowledge

        def searchable_kb_filters(permission) -> list:
            # Configured knowledge bases with the given permission that have
            # at least one chunk and status == 1.
            return [
                kb_model.id.in_(cfg.kb_ids),
                kb_model.permission_id == permission,
                kb_model.chunk_num > 0,
                kb_model.status == 1,
            ]

        with get_db_context() as db:
            # Private knowledge bases are searched directly.
            existing_ids = knowledge_repository.get_chunked_knowledgeids(
                db=db,
                filters=searchable_kb_filters(knowledge_model.PermissionType.Private),
            )

            # Shared knowledge bases are replaced by the source knowledge
            # bases recorded for the configured target IDs.
            share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
                db=db,
                filters=searchable_kb_filters(knowledge_model.PermissionType.Share),
            )
            if share_ids:
                share_filters = [
                    knowledgeshare_model.KnowledgeShare.target_kb_id.in_(cfg.kb_ids)
                ]
                source_ids = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
                    db=db,
                    filters=share_filters,
                )
                existing_ids.extend(source_ids)

            if not existing_ids:
                raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")

            # One comma-separated index list covers every usable knowledge base.
            indices = ",".join(
                f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids
            )

            # Only the first usable knowledge base is loaded; the vector
            # service is built from it alone.
            primary_kb_id = existing_ids[0]
            db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=primary_kb_id)
            if not db_knowledge:
                raise RuntimeError("The knowledge base does not exist or access is denied.")

            vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)

            def full_text_hits():
                return vector_service.search_by_full_text(
                    query=query,
                    top_k=cfg.top_k,
                    indices=indices,
                    score_threshold=cfg.similarity_threshold,
                )

            def vector_hits():
                # NOTE(review): the vector search threshold is taken from
                # vector_similarity_weight, not similarity_threshold - confirm
                # this is intended.
                return vector_service.search_by_vector(
                    query=query,
                    top_k=cfg.top_k,
                    indices=indices,
                    score_threshold=cfg.vector_similarity_weight,
                )

            mode = cfg.retrieve_type
            if mode == RetrieveType.PARTICIPLE:
                hits = full_text_hits()
            elif mode == RetrieveType.SEMANTIC:
                hits = vector_hits()
            else:
                # Any other retrieve type: run both searches, keep the first
                # document seen per doc_id (vector results first), then rerank.
                vector_docs = vector_hits()
                text_docs = full_text_hits()
                first_by_doc_id = {}
                for doc in vector_docs + text_docs:
                    first_by_doc_id.setdefault(doc.metadata["doc_id"], doc)
                hits = vector_service.rerank(
                    query=query,
                    docs=list(first_by_doc_id.values()),
                    top_k=cfg.top_k,
                )

            return [hit.model_dump() for hit in hits]

View File

@@ -7,7 +7,7 @@
import logging
from typing import Any, Union
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode
@@ -29,7 +29,7 @@ WorkflowNode = Union[
AgentNode,
TransformNode,
AssignerNode,
# KnowledgeRetrievalNode,
KnowledgeRetrievalNode,
]
@@ -47,7 +47,7 @@ class NodeFactory:
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode,
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode,
}

View File

@@ -23,7 +23,7 @@ class StartNode(BaseNode):
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化 Start 节点
@@ -32,10 +32,10 @@ class StartNode(BaseNode):
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置
self.typed_config = StartNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行 start 节点业务逻辑
@@ -48,13 +48,13 @@ class StartNode(BaseNode):
包含系统参数、会话变量和自定义变量的字典
"""
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 创建变量池实例(在方法内复用)
pool = self.get_variable_pool(state)
# 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(pool)
# 返回业务数据(包含自定义变量)
result = {
"message": pool.get("sys.message"),
@@ -64,14 +64,14 @@ class StartNode(BaseNode):
"user_id": pool.get("sys.user_id"),
**custom_vars # 自定义变量作为节点输出的一部分
}
logger.info(
f"节点 {self.node_id} (Start) 执行完成,"
f"输出了 {len(custom_vars)} 个自定义变量"
)
return result
def _process_custom_variables(self, pool) -> dict[str, Any]:
"""处理自定义变量
@@ -88,34 +88,33 @@ class StartNode(BaseNode):
"""
# 获取输入数据中的自定义变量
input_variables = pool.get("sys.input_variables", default={})
processed = {}
# 遍历配置的变量定义
for var_def in self.typed_config.variables:
var_name = var_def.name
# 检查变量是否存在
if var_name in input_variables:
# 使用用户提供的值
processed[var_name] = input_variables[var_name]
elif var_def.required:
# 必需变量缺失
raise ValueError(
f"缺少必需的输入变量: {var_name}"
+ (f" ({var_def.description})" if var_def.description else "")
)
elif var_def.default is not None:
# 使用默认值
processed[var_name] = var_def.default
logger.debug(
f"变量 '{var_name}' 使用默认值: {var_def.default}"
)
return processed
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)
@@ -127,7 +126,7 @@ class StartNode(BaseNode):
输入数据字典
"""
pool = self.get_variable_pool(state)
return {
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),

View File

@@ -52,7 +52,7 @@ def get_knowledges_paginated(
raise
def get_chunded_knowledgeids(
def get_chunked_knowledgeids(
db: Session,
filters: list
) -> list:

View File

@@ -31,7 +31,7 @@ class BaseDataSchema(BaseModel):
# 保持原有必需字段为可选,以兼容不同数据源
id: Optional[str] = Field(None, description="The unique identifier for the data entry.")
statement: Optional[str] = Field(None, description="The statement text.")
created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.")
created_at: Optional[str] = Field(None, description="The creation timestamp in ISO 8601 format.")
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
description: Optional[str] = Field(None, description="The description of the data entry.")
@@ -46,6 +46,14 @@ class BaseDataSchema(BaseModel):
relationship: Optional[Union[str, Dict[str, Any]]] = Field(None, description="The relationship object or string.")
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
@model_validator(mode="before")
@classmethod
def _set_default_created_at(cls, v):
    """Fill in ``created_at`` with the current time when it is missing.

    Runs before field validation. Only dict input is touched; any other
    input is returned unchanged.

    Args:
        v: The raw value being validated into the model.

    Returns:
        ``v`` itself, or a shallow copy of it with ``created_at`` set to the
        current local time in ISO 8601 format when the key was absent or None.
    """
    if isinstance(v, dict) and v.get("created_at") is None:
        from datetime import datetime

        # Build a new dict instead of writing into the caller's object, so
        # validating a payload never changes it as a side effect.
        v = {**v, "created_at": datetime.now().isoformat()}
    return v
class QualityAssessmentSchema(BaseModel):
"""Schema for memory quality assessment results."""

View File

@@ -45,7 +45,7 @@ def get_chunded_knowledgeids(
business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}")
try:
items = knowledge_repository.get_chunded_knowledgeids(
items = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)

View File

@@ -44,11 +44,11 @@ nodes:
- role: user
content: "{{ sys.message }}"
model_id: gpt-3.5-turbo
model_id: null
temperature: 0.7
max_tokens: 1000
position:
x: 300
x: 500
y: 100
- id: end
@@ -57,7 +57,7 @@ nodes:
config:
output: "{{ llm_qa.output }}"
position:
x: 500
x: 900
y: 100
edges:

View File

@@ -3,7 +3,7 @@ version: '3.9'
services:
# MCP Server - standalone service
mcp-server:
image: redbear-mem:latest
image: redbear-mem-open:latest
container_name: mcp-server
ports:
- "8081:8081" # MCP server port
@@ -28,14 +28,14 @@ services:
# FastAPI application - connects to MCP server
api:
image: redbear-mem:latest
image: redbear-mem-open:latest
container_name: api
ports:
- "8002:8000"
env_file:
- .env
environment:
- MCP_SERVER_URL=http://mcp-server:8081
- MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name
- SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces
volumes:
- ./files:/files
@@ -51,12 +51,12 @@ services:
# Celery worker - connects to MCP server
worker:
image: redbear-mem:latest
image: redbear-mem-open:latest
container_name: worker
env_file:
- .env
environment:
- MCP_SERVER_URL=http://mcp-server:8081
- MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro

View File

@@ -71,6 +71,18 @@ ENABLE_SINGLE_SESSION=
MAX_FILE_SIZE=52428800 # 50MB: 50 * 1024 * 1024
FILE_PATH=/files
# RAG Setting
DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
HF_ENDPOINT=https://hf-mirror.com
MINERU_EXECUTABLE=mineru
MINERU_APISERVER=http://host.docker.internal:9987
MINERU_OUTPUT_DIR=/files
MINERU_BACKEND=pipeline
MINERU_DELETE_OUTPUT=1
TEXTLN_APISERVER=https://api.textin.com/ai/service/v1/pdf_to_markdown
TEXTLN_APP_ID=
TEXTLN_SECRET_CODE=
# VOLC ASR
VOLC_APP_KEY=
VOLC_ACCESS_KEY=