Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
This commit is contained in:
@@ -131,6 +131,18 @@ def main():
|
|||||||
|
|
||||||
# Get MCP port from environment (default: 8081)
|
# Get MCP port from environment (default: 8081)
|
||||||
mcp_port = int(os.getenv("MCP_PORT", "8081"))
|
mcp_port = int(os.getenv("MCP_PORT", "8081"))
|
||||||
|
logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
|
||||||
|
|
||||||
|
# Configure DNS rebinding protection for Docker container compatibility
|
||||||
|
from mcp.server.fastmcp.server import TransportSecuritySettings
|
||||||
|
|
||||||
|
# Disable DNS rebinding protection to allow Docker container hostnames
|
||||||
|
# This allows containers to connect using service names like 'mcp-server'
|
||||||
|
mcp.settings.transport_security = TransportSecuritySettings(
|
||||||
|
enable_dns_rebinding_protection=False,
|
||||||
|
)
|
||||||
|
logger.info("DNS rebinding protection: disabled for Docker container compatibility")
|
||||||
|
|
||||||
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
|
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
|
||||||
|
|
||||||
# Run the server with SSE transport for HTTP connections
|
# Run the server with SSE transport for HTTP connections
|
||||||
|
|||||||
@@ -50,9 +50,7 @@
|
|||||||
"entity2_name": "用户",
|
"entity2_name": "用户",
|
||||||
"entity2": {
|
"entity2": {
|
||||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
"name": "用户"
|
||||||
"name": "用户",
|
|
||||||
"id": "3d3896797b334572a80d57590026063d"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -62,9 +60,7 @@
|
|||||||
"entity2_name": "身份信息",
|
"entity2_name": "身份信息",
|
||||||
"entity2": {
|
"entity2": {
|
||||||
"description": "用于个人身份识别的数据",
|
"description": "用于个人身份识别的数据",
|
||||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
"name": "身份信息"
|
||||||
"name": "身份信息",
|
|
||||||
"id": "aa766a517e82490599a9b3af54cfd933"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -74,9 +70,7 @@
|
|||||||
"entity2_name": "6222023847595898",
|
"entity2_name": "6222023847595898",
|
||||||
"entity2": {
|
"entity2": {
|
||||||
"description": "用户的银行卡号码",
|
"description": "用户的银行卡号码",
|
||||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
"name": "6222023847595898"
|
||||||
"name": "6222023847595898",
|
|
||||||
"id": "610ba361918f4e68a65ce6ad06e5c7a0"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -88,9 +82,7 @@
|
|||||||
"entity_idx": 1,
|
"entity_idx": 1,
|
||||||
"aliases": ["上海办"],
|
"aliases": ["上海办"],
|
||||||
"description": "位于上海的工作办公场所",
|
"description": "位于上海的工作办公场所",
|
||||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
"name": "上海办公室"
|
||||||
"name": "上海办公室",
|
|
||||||
"id": "fb702ef695c14e14af3e56786bc8815b"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -101,9 +93,7 @@
|
|||||||
"entity2": {
|
"entity2": {
|
||||||
"aliases": ["京", "京城", "北平"],
|
"aliases": ["京", "京城", "北平"],
|
||||||
"description": "中国的首都城市,用户主要工作和生活所在地",
|
"description": "中国的首都城市,用户主要工作和生活所在地",
|
||||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
"name": "北京"
|
||||||
"name": "北京",
|
|
||||||
"id": "81b2d1a571bb46a08a2d7a1e87efb945"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -113,9 +103,7 @@
|
|||||||
"entity2_name": "身份证号",
|
"entity2_name": "身份证号",
|
||||||
"entity2": {
|
"entity2": {
|
||||||
"description": "中华人民共和国公民的身份号码",
|
"description": "中华人民共和国公民的身份号码",
|
||||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
"name": "身份证号"
|
||||||
"name": "身份证号",
|
|
||||||
"id": "3e5f920645b2404fadb0e9ff60d1306e"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -269,8 +269,6 @@ class ReflectionEngine:
|
|||||||
# # 检查是否真的有冲突
|
# # 检查是否真的有冲突
|
||||||
conflicts_found=''
|
conflicts_found=''
|
||||||
|
|
||||||
# 记录冲突数据
|
|
||||||
await self._log_data("conflict", conflict_data)
|
|
||||||
conflicts_found=''
|
conflicts_found=''
|
||||||
# 3. 解决冲突
|
# 3. 解决冲突
|
||||||
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
|
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
|
||||||
@@ -288,8 +286,6 @@ class ReflectionEngine:
|
|||||||
conflicts_resolved = len(solved_data)
|
conflicts_resolved = len(solved_data)
|
||||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||||
|
|
||||||
# 记录解决方案
|
|
||||||
await self._log_data("solved_data", solved_data)
|
|
||||||
|
|
||||||
# 4. 应用反思结果(更新记忆库)
|
# 4. 应用反思结果(更新记忆库)
|
||||||
memories_updated=await self._apply_reflection_results(solved_data)
|
memories_updated=await self._apply_reflection_results(solved_data)
|
||||||
@@ -390,14 +386,7 @@ class ReflectionEngine:
|
|||||||
memory_verifies.append(item['memory_verify'])
|
memory_verifies.append(item['memory_verify'])
|
||||||
result_data['memory_verifies'] = memory_verifies
|
result_data['memory_verifies'] = memory_verifies
|
||||||
result_data['quality_assessments'] = quality_assessments
|
result_data['quality_assessments'] = quality_assessments
|
||||||
|
conflicts_found=''
|
||||||
# 检查是否真的有冲突
|
|
||||||
has_conflict = conflict_data[0].get('conflict', False)
|
|
||||||
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
|
|
||||||
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
|
|
||||||
|
|
||||||
# 记录冲突数据
|
|
||||||
await self._log_data("conflict", conflict_data)
|
|
||||||
|
|
||||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||||
cleaned_conflict_data = []
|
cleaned_conflict_data = []
|
||||||
@@ -407,6 +396,7 @@ class ReflectionEngine:
|
|||||||
'conflict': item['conflict']
|
'conflict': item['conflict']
|
||||||
}
|
}
|
||||||
cleaned_conflict_data.append(cleaned_item)
|
cleaned_conflict_data.append(cleaned_item)
|
||||||
|
|
||||||
# 3. 解决冲突
|
# 3. 解决冲突
|
||||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
|
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
|
||||||
if not solved_data:
|
if not solved_data:
|
||||||
@@ -645,26 +635,7 @@ class ReflectionEngine:
|
|||||||
success_count = await neo4j_data(changes)
|
success_count = await neo4j_data(changes)
|
||||||
return success_count
|
return success_count
|
||||||
|
|
||||||
async def _log_data(self, label: str, data: Any) -> None:
|
|
||||||
"""
|
|
||||||
记录数据到文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
label: 数据标签
|
|
||||||
data: 要记录的数据
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _write():
|
|
||||||
try:
|
|
||||||
with open("reflexion_data.json", "a", encoding="utf-8") as f:
|
|
||||||
f.write(f"### {label} ###\n")
|
|
||||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
|
||||||
f.write("\n\n")
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"记录数据失败: {e}")
|
|
||||||
|
|
||||||
# 在后台线程中执行写入,避免阻塞事件循环
|
|
||||||
await asyncio.to_thread(_write)
|
|
||||||
|
|
||||||
# 基于时间的反思方法
|
# 基于时间的反思方法
|
||||||
async def time_based_reflection(
|
async def time_based_reflection(
|
||||||
@@ -753,4 +724,3 @@ class ReflectionEngine:
|
|||||||
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,10 +17,12 @@
|
|||||||
- **日期属性冲突**: 同一人的生日等单值属性出现多值
|
- **日期属性冲突**: 同一人的生日等单值属性出现多值
|
||||||
- **先后约束违反**: 存在A→B约束但t(A)>t(B)(如入学>毕业)
|
- **先后约束违反**: 存在A→B约束但t(A)>t(B)(如入学>毕业)
|
||||||
- **互斥重叠**: 同一时间出现在不同地点等互斥事件
|
- **互斥重叠**: 同一时间出现在不同地点等互斥事件
|
||||||
|
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
|
||||||
### 事实冲突
|
### 事实冲突
|
||||||
- **属性互斥**: 同一实体的相反属性(喜欢↔不喜欢)
|
- **属性互斥**: 同一实体的相反属性(喜欢↔不喜欢)
|
||||||
- **关系矛盾**: 同一实体在相同语境下的不同关系描述
|
- **关系矛盾**: 同一实体在相同语境下的不同关系描述
|
||||||
- **身份冲突**: 同一实体被赋予不同类型或角色
|
- **身份冲突**: 同一实体被赋予不同类型或角色
|
||||||
|
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
|
||||||
### 混合冲突
|
### 混合冲突
|
||||||
检测所有逻辑不一致或相互矛盾的记录。
|
检测所有逻辑不一致或相互矛盾的记录。
|
||||||
**检测原则**:
|
**检测原则**:
|
||||||
|
|||||||
@@ -171,7 +171,6 @@
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
**输出要求**:
|
**输出要求**:
|
||||||
- 只输出JSON,不添加解释文本
|
- 只输出JSON,不添加解释文本
|
||||||
- 使用标准双引号,必要时转义
|
- 使用标准双引号,必要时转义
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import List, Dict, Any
|
|||||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||||
|
|
||||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any],
|
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||||
baseline: str = "TIME",
|
baseline: str = "TIME",
|
||||||
memory_verify: bool = False,quality_assessment:bool = False,
|
memory_verify: bool = False,quality_assessment:bool = False,
|
||||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||||
@@ -16,7 +16,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
evaluate_data: The data to evaluate
|
evaluate_data: The data to evaluate
|
||||||
schema: The JSON schema to use for the output.
|
schema: The Pydantic model class or JSON schema to use for the output.
|
||||||
baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT)
|
baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT)
|
||||||
memory_verify: Whether to enable memory verification for privacy detection
|
memory_verify: Whether to enable memory verification for privacy detection
|
||||||
|
|
||||||
@@ -25,9 +25,17 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
|
|||||||
"""
|
"""
|
||||||
template = prompt_env.get_template("evaluate.jinja2")
|
template = prompt_env.get_template("evaluate.jinja2")
|
||||||
|
|
||||||
|
# Convert Pydantic model to JSON schema if needed
|
||||||
|
if hasattr(schema, 'model_json_schema'):
|
||||||
|
json_schema = schema.model_json_schema()
|
||||||
|
elif hasattr(schema, 'schema'):
|
||||||
|
json_schema = schema.schema()
|
||||||
|
else:
|
||||||
|
json_schema = schema
|
||||||
|
|
||||||
rendered_prompt = template.render(
|
rendered_prompt = template.render(
|
||||||
evaluate_data=evaluate_data,
|
evaluate_data=evaluate_data,
|
||||||
json_schema=schema,
|
json_schema=json_schema,
|
||||||
baseline=baseline,
|
baseline=baseline,
|
||||||
memory_verify=memory_verify,
|
memory_verify=memory_verify,
|
||||||
quality_assessment=quality_assessment,
|
quality_assessment=quality_assessment,
|
||||||
@@ -36,14 +44,15 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
|
|||||||
)
|
)
|
||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
|
|
||||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False,
|
|
||||||
|
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
|
||||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||||
"""
|
"""
|
||||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: The data to reflex on.
|
data: The data to reflex on.
|
||||||
schema: The JSON schema to use for the output.
|
schema: The Pydantic model class or JSON schema to use for the output.
|
||||||
baseline: The baseline type for conflict resolution.
|
baseline: The baseline type for conflict resolution.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -51,7 +60,15 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any],
|
|||||||
"""
|
"""
|
||||||
template = prompt_env.get_template("reflexion.jinja2")
|
template = prompt_env.get_template("reflexion.jinja2")
|
||||||
|
|
||||||
rendered_prompt = template.render(data=data, json_schema=schema,
|
# Convert Pydantic model to JSON schema if needed
|
||||||
|
if hasattr(schema, 'model_json_schema'):
|
||||||
|
json_schema = schema.model_json_schema()
|
||||||
|
elif hasattr(schema, 'schema'):
|
||||||
|
json_schema = schema.schema()
|
||||||
|
else:
|
||||||
|
json_schema = schema
|
||||||
|
|
||||||
|
rendered_prompt = template.render(data=data, json_schema=json_schema,
|
||||||
baseline=baseline,memory_verify=memory_verify,
|
baseline=baseline,memory_verify=memory_verify,
|
||||||
statement_databasets=statement_databasets,language_type=language_type)
|
statement_databasets=statement_databasets,language_type=language_type)
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from io import BytesIO
|
|||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pdfplumber
|
import pdfplumber
|
||||||
@@ -439,7 +439,7 @@ class MinerUParser(RAGPdfParser):
|
|||||||
def parse_pdf(
|
def parse_pdf(
|
||||||
self,
|
self,
|
||||||
filepath: str | PathLike[str],
|
filepath: str | PathLike[str],
|
||||||
binary: BytesIO | bytes,
|
binary: Optional[Union[BytesIO, bytes]] = None,
|
||||||
callback: Optional[Callable] = None,
|
callback: Optional[Callable] = None,
|
||||||
*,
|
*,
|
||||||
output_dir: Optional[str] = None,
|
output_dir: Optional[str] = None,
|
||||||
|
|||||||
@@ -740,8 +740,9 @@ class ElasticSearchVector(BaseVector):
|
|||||||
self._client.indices.create(index=self._collection_name, body=index_mapping)
|
self._client.indices.create(index=self._collection_name, body=index_mapping)
|
||||||
|
|
||||||
|
|
||||||
class ElasticSearchVectorFactory(ABC):
|
class ElasticSearchVectorFactory:
|
||||||
def init_vector(self, knowledge: Knowledge) -> ElasticSearchVector:
|
@staticmethod
|
||||||
|
def init_vector(knowledge: Knowledge) -> ElasticSearchVector:
|
||||||
collection_name = f"Vector_index_{knowledge.id}_Node"
|
collection_name = f"Vector_index_{knowledge.id}_Node"
|
||||||
|
|
||||||
# Use regular Elasticsearch with config values
|
# Use regular Elasticsearch with config values
|
||||||
@@ -763,17 +764,17 @@ class ElasticSearchVectorFactory(ABC):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if knowledge.embedding and knowledge.reranker:
|
if knowledge.embedding is None:
|
||||||
return ElasticSearchVector(
|
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
|
||||||
index_name=collection_name,
|
if knowledge.reranker is None:
|
||||||
config=ElasticSearchConfig(**config_dict),
|
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
|
||||||
embedding_config=knowledge.embedding.api_keys[0],
|
|
||||||
reranker_config=knowledge.reranker.api_keys[0]
|
return ElasticSearchVector(
|
||||||
)
|
index_name=collection_name,
|
||||||
else:
|
config=ElasticSearchConfig(**config_dict),
|
||||||
if knowledge.embedding is None:
|
embedding_config=knowledge.embedding.api_keys[0],
|
||||||
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
|
reranker_config=knowledge.reranker.api_keys[0]
|
||||||
if knowledge.reranker is None:
|
)
|
||||||
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from app.core.workflow.nodes.assigner import AssignerNode
|
|||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.workflow.nodes.end import EndNode
|
from app.core.workflow.nodes.end import EndNode
|
||||||
from app.core.workflow.nodes.if_else import IfElseNode
|
from app.core.workflow.nodes.if_else import IfElseNode
|
||||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||||
from app.core.workflow.nodes.llm import LLMNode
|
from app.core.workflow.nodes.llm import LLMNode
|
||||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||||
from app.core.workflow.nodes.start import StartNode
|
from app.core.workflow.nodes.start import StartNode
|
||||||
@@ -26,6 +26,6 @@ __all__ = [
|
|||||||
"EndNode",
|
"EndNode",
|
||||||
"NodeFactory",
|
"NodeFactory",
|
||||||
"WorkflowNode",
|
"WorkflowNode",
|
||||||
# "KnowledgeRetrievalNode",
|
"KnowledgeRetrievalNode",
|
||||||
"AssignerNode",
|
"AssignerNode",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,21 +1,32 @@
|
|||||||
from pydantic import Field
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||||
|
|
||||||
|
|
||||||
class AssignerNodeConfig(BaseNodeConfig):
|
class AssignmentItem(BaseModel):
|
||||||
|
"""
|
||||||
|
Single assignment definition.
|
||||||
|
"""
|
||||||
|
|
||||||
variable_selector: str | list[str] = Field(
|
variable_selector: str | list[str] = Field(
|
||||||
...,
|
...,
|
||||||
description="Variables to be assigned",
|
description="Target variable name(s) to assign",
|
||||||
)
|
)
|
||||||
|
|
||||||
operation: AssignmentOperator = Field(
|
operation: AssignmentOperator = Field(
|
||||||
...,
|
...,
|
||||||
description="Operator to assign",
|
description="Assignment operator",
|
||||||
)
|
)
|
||||||
|
|
||||||
value: str | list[str] = Field(
|
value: str | list[str] = Field(
|
||||||
...,
|
...,
|
||||||
description="Values to assign",
|
description="Value(s) to assign to the variable(s)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AssignerNodeConfig(BaseNodeConfig):
|
||||||
|
assignments: list[AssignmentItem] = Field(
|
||||||
|
...,
|
||||||
|
description="List of variable assignment definitions",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -29,52 +29,52 @@ class AssignerNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||||
pool = VariablePool(state)
|
pool = VariablePool(state)
|
||||||
|
for assignment in self.typed_config.assignments:
|
||||||
|
# Get the target variable selector (e.g., "conv.test")
|
||||||
|
variable_selector = assignment.variable_selector
|
||||||
|
if isinstance(variable_selector, str):
|
||||||
|
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
|
||||||
|
variable_selector = variable_selector.split('.')
|
||||||
|
|
||||||
# Get the target variable selector (e.g., "conv.test")
|
# Only conversation variables ('conv') are allowed
|
||||||
variable_selector = self.typed_config.variable_selector
|
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
|
||||||
if isinstance(variable_selector, str):
|
raise ValueError("Only conversation variables can be assigned.")
|
||||||
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
|
|
||||||
variable_selector = variable_selector.split('.')
|
|
||||||
|
|
||||||
# Only conversation variables ('conv') are allowed
|
# Get the value or expression to assign
|
||||||
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
|
value = assignment.value
|
||||||
raise ValueError("Only conversation variables can be assigned.")
|
if isinstance(value, list):
|
||||||
|
value = '.'.join(value)
|
||||||
|
value = ExpressionEvaluator.evaluate(
|
||||||
|
expression=value,
|
||||||
|
variables=pool.get_all_conversation_vars(),
|
||||||
|
node_outputs=pool.get_all_node_outputs(),
|
||||||
|
system_vars=pool.get_all_system_vars(),
|
||||||
|
)
|
||||||
|
|
||||||
# Get the value or expression to assign
|
# Select the appropriate assignment operator instance based on the target variable type
|
||||||
value = self.typed_config.value
|
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
|
||||||
if isinstance(value, list):
|
pool, variable_selector, value
|
||||||
value = '.'.join(value)
|
)
|
||||||
value = ExpressionEvaluator.evaluate(
|
|
||||||
expression=value,
|
|
||||||
variables=pool.get_all_conversation_vars(),
|
|
||||||
node_outputs=pool.get_all_node_outputs(),
|
|
||||||
system_vars=pool.get_all_system_vars(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Select the appropriate assignment operator instance based on the target variable type
|
# Execute the configured assignment operation
|
||||||
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
|
match assignment.operation:
|
||||||
pool, variable_selector, value
|
case AssignmentOperator.ASSIGN:
|
||||||
)
|
operator.assign()
|
||||||
|
case AssignmentOperator.CLEAR:
|
||||||
# Execute the configured assignment operation
|
operator.clear()
|
||||||
match self.typed_config.operation:
|
case AssignmentOperator.ADD:
|
||||||
case AssignmentOperator.ASSIGN:
|
operator.add()
|
||||||
operator.assign()
|
case AssignmentOperator.SUBTRACT:
|
||||||
case AssignmentOperator.CLEAR:
|
operator.subtract()
|
||||||
operator.clear()
|
case AssignmentOperator.MULTIPLY:
|
||||||
case AssignmentOperator.ADD:
|
operator.multiply()
|
||||||
operator.add()
|
case AssignmentOperator.DIVIDE:
|
||||||
case AssignmentOperator.SUBTRACT:
|
operator.divide()
|
||||||
operator.subtract()
|
case AssignmentOperator.APPEND:
|
||||||
case AssignmentOperator.MULTIPLY:
|
operator.append()
|
||||||
operator.multiply()
|
case AssignmentOperator.REMOVE_FIRST:
|
||||||
case AssignmentOperator.DIVIDE:
|
operator.remove_first()
|
||||||
operator.divide()
|
case AssignmentOperator.REMOVE_LAST:
|
||||||
case AssignmentOperator.APPEND:
|
operator.remove_last()
|
||||||
operator.append()
|
case _:
|
||||||
case AssignmentOperator.REMOVE_FIRST:
|
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||||
operator.remove_first()
|
|
||||||
case AssignmentOperator.REMOVE_LAST:
|
|
||||||
operator.remove_last()
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Invalid Operator: {self.typed_config.operation}")
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
|||||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||||
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||||
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -30,6 +30,6 @@ __all__ = [
|
|||||||
"AgentNodeConfig",
|
"AgentNodeConfig",
|
||||||
"TransformNodeConfig",
|
"TransformNodeConfig",
|
||||||
"IfElseNodeConfig",
|
"IfElseNodeConfig",
|
||||||
# "KnowledgeRetrievalNodeConfig",
|
"KnowledgeRetrievalNodeConfig",
|
||||||
"AssignerNodeConfig",
|
"AssignerNodeConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -9,28 +9,29 @@ import re
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EndNode(BaseNode):
|
class EndNode(BaseNode):
|
||||||
"""End 节点
|
"""End 节点
|
||||||
|
|
||||||
工作流的结束节点,根据配置的模板输出最终结果。
|
工作流的结束节点,根据配置的模板输出最终结果。
|
||||||
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
|
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> str:
|
async def execute(self, state: WorkflowState) -> str:
|
||||||
"""执行 end 节点业务逻辑
|
"""执行 end 节点业务逻辑
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 工作流状态
|
state: 工作流状态
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
最终输出字符串
|
最终输出字符串
|
||||||
"""
|
"""
|
||||||
logger.info(f"节点 {self.node_id} (End) 开始执行")
|
logger.info(f"节点 {self.node_id} (End) 开始执行")
|
||||||
|
|
||||||
# 获取配置的输出模板
|
# 获取配置的输出模板
|
||||||
output_template = self.config.get("output")
|
output_template = self.config.get("output")
|
||||||
|
|
||||||
@@ -39,11 +40,11 @@ class EndNode(BaseNode):
|
|||||||
output = self._render_template(output_template, state)
|
output = self._render_template(output_template, state)
|
||||||
else:
|
else:
|
||||||
output = "工作流已完成"
|
output = "工作流已完成"
|
||||||
|
|
||||||
# 统计信息(用于日志)
|
# 统计信息(用于日志)
|
||||||
node_outputs = state.get("node_outputs", {})
|
node_outputs = state.get("node_outputs", {})
|
||||||
total_nodes = len(node_outputs)
|
total_nodes = len(node_outputs)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||||
|
|
||||||
return output
|
return output
|
||||||
@@ -127,24 +128,26 @@ class EndNode(BaseNode):
|
|||||||
return parts
|
return parts
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState):
|
async def execute_stream(self, state: WorkflowState):
|
||||||
"""流式执行 end 节点业务逻辑
|
"""Execute End node business logic (streaming)
|
||||||
|
|
||||||
智能输出策略:
|
Smart output strategy:
|
||||||
1. 检测模板中是否引用了直接上游节点
|
1. Check if template references a direct upstream LLM node
|
||||||
2. 如果引用了,只输出该引用**之后**的部分(后缀)
|
2. If yes, only output the part AFTER that reference (suffix)
|
||||||
3. 前缀和引用内容已经在上游节点流式输出时发送了
|
3. Prefix and LLM content have already been sent during LLM node streaming
|
||||||
|
|
||||||
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
Note: Only LLM nodes get this special treatment. Other node types output normally.
|
||||||
- 直接上游节点是 llm_qa
|
|
||||||
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
|
Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
||||||
- LLM 内容在 LLM 节点流式输出
|
- Direct upstream LLM node is llm_qa
|
||||||
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
|
- Prefix '{{start.test}}hahaha ' was sent before LLM node streaming
|
||||||
|
- LLM content was streamed during LLM node execution
|
||||||
|
- End node only outputs ' lalalalala a' (suffix, sent as one chunk)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 工作流状态
|
state: Workflow state
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
完成标记
|
Completion marker
|
||||||
"""
|
"""
|
||||||
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
||||||
|
|
||||||
@@ -156,39 +159,45 @@ class EndNode(BaseNode):
|
|||||||
yield {"__final__": True, "result": output}
|
yield {"__final__": True, "result": output}
|
||||||
return
|
return
|
||||||
|
|
||||||
# 找到直接上游节点
|
# Find direct upstream LLM nodes
|
||||||
direct_upstream_nodes = []
|
direct_upstream_llm_nodes = []
|
||||||
for edge in self.workflow_config.get("edges", []):
|
for edge in self.workflow_config.get("edges", []):
|
||||||
if edge.get("target") == self.node_id:
|
if edge.get("target") == self.node_id:
|
||||||
source_node_id = edge.get("source")
|
source_node_id = edge.get("source")
|
||||||
direct_upstream_nodes.append(source_node_id)
|
# Check if the source node is an LLM node
|
||||||
|
for node in self.workflow_config.get("nodes", []):
|
||||||
|
print("="*50)
|
||||||
|
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
||||||
|
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
||||||
|
direct_upstream_llm_nodes.append(source_node_id)
|
||||||
|
break
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
|
logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")
|
||||||
|
|
||||||
# 解析模板部分
|
# Parse template parts
|
||||||
parts = self._parse_template_parts(output_template, state)
|
parts = self._parse_template_parts(output_template, state)
|
||||||
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
|
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
|
||||||
for i, part in enumerate(parts):
|
for i, part in enumerate(parts):
|
||||||
logger.info(f"[模板解析] part[{i}]: {part}")
|
logger.info(f"[模板解析] part[{i}]: {part}")
|
||||||
|
|
||||||
# 找到第一个引用直接上游节点的动态引用
|
# Find the first reference to a direct upstream LLM node
|
||||||
upstream_ref_index = None
|
upstream_llm_ref_index = None
|
||||||
for i, part in enumerate(parts):
|
for i, part in enumerate(parts):
|
||||||
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
|
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
|
||||||
upstream_ref_index = i
|
upstream_llm_ref_index = i
|
||||||
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
|
logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}")
|
||||||
break
|
break
|
||||||
|
|
||||||
if upstream_ref_index is None:
|
if upstream_llm_ref_index is None:
|
||||||
# 没有引用直接上游节点,输出完整模板内容
|
# No reference to direct upstream LLM node, output complete template content
|
||||||
output = self._render_template(output_template, state)
|
output = self._render_template(output_template, state)
|
||||||
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'")
|
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
|
||||||
|
|
||||||
# 通过 writer 发送完整内容(作为一个 message chunk)
|
# Send complete content via writer (as a single message chunk)
|
||||||
from langgraph.config import get_stream_writer
|
from langgraph.config import get_stream_writer
|
||||||
writer = get_stream_writer()
|
writer = get_stream_writer()
|
||||||
writer({
|
writer({
|
||||||
"type": "message", # End 节点的输出使用 message 类型
|
"type": "message", # End node output uses message type
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
"chunk": output,
|
"chunk": output,
|
||||||
"full_content": output,
|
"full_content": output,
|
||||||
@@ -197,17 +206,17 @@ class EndNode(BaseNode):
|
|||||||
})
|
})
|
||||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
||||||
|
|
||||||
# yield 完成标记
|
# yield completion marker
|
||||||
yield {"__final__": True, "result": output}
|
yield {"__final__": True, "result": output}
|
||||||
return
|
return
|
||||||
|
|
||||||
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
|
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
|
||||||
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
|
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
||||||
|
|
||||||
# 收集后缀部分
|
# Collect suffix parts
|
||||||
suffix_parts = []
|
suffix_parts = []
|
||||||
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}")
|
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1} 到 {len(parts) - 1}")
|
||||||
for i in range(upstream_ref_index + 1, len(parts)):
|
for i in range(upstream_llm_ref_index + 1, len(parts)):
|
||||||
part = parts[i]
|
part = parts[i]
|
||||||
logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
|
logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
|
||||||
if part["type"] == "static":
|
if part["type"] == "static":
|
||||||
@@ -219,7 +228,7 @@ class EndNode(BaseNode):
|
|||||||
# Other dynamic references (if there are multiple references)
|
# Other dynamic references (if there are multiple references)
|
||||||
node_id = part["node_id"]
|
node_id = part["node_id"]
|
||||||
field = part["field"]
|
field = part["field"]
|
||||||
|
|
||||||
# Use VariablePool to get variable value
|
# Use VariablePool to get variable value
|
||||||
pool = self.get_variable_pool(state)
|
pool = self.get_variable_pool(state)
|
||||||
try:
|
try:
|
||||||
@@ -232,7 +241,7 @@ class EndNode(BaseNode):
|
|||||||
|
|
||||||
# Convert to string if not None
|
# Convert to string if not None
|
||||||
suffix_parts.append(str(content) if content is not None else "")
|
suffix_parts.append(str(content) if content is not None else "")
|
||||||
|
|
||||||
# 拼接后缀
|
# 拼接后缀
|
||||||
suffix = "".join(suffix_parts)
|
suffix = "".join(suffix_parts)
|
||||||
|
|
||||||
@@ -261,8 +270,8 @@ class EndNode(BaseNode):
|
|||||||
})
|
})
|
||||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}")
|
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
node_outputs = state.get("node_outputs", {})
|
node_outputs = state.get("node_outputs", {})
|
||||||
total_nodes = len(node_outputs)
|
total_nodes = len(node_outputs)
|
||||||
|
|||||||
4
api/app/core/workflow/nodes/knowledge/__init__.py
Normal file
4
api/app/core/workflow/nodes/knowledge/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||||
|
from app.core.workflow.nodes.knowledge.node import KnowledgeRetrievalNode
|
||||||
|
|
||||||
|
__all__ = ["KnowledgeRetrievalNode", "KnowledgeRetrievalNodeConfig"]
|
||||||
38
api/app/core/workflow/nodes/knowledge/config.py
Normal file
38
api/app/core/workflow/nodes/knowledge/config.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
from app.schemas.chunk_schema import RetrieveType
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||||
|
query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Search query string"
|
||||||
|
)
|
||||||
|
|
||||||
|
kb_ids: list[UUID] = Field(
|
||||||
|
...,
|
||||||
|
description="Knowledge base IDs"
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_threshold: float = Field(
|
||||||
|
default=0.2,
|
||||||
|
description="Knowledge base similarity threshold"
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_similarity_weight: float = Field(
|
||||||
|
default=0.3,
|
||||||
|
description="Knowledge base vector similarity weight"
|
||||||
|
)
|
||||||
|
|
||||||
|
top_k: int = Field(
|
||||||
|
default=4,
|
||||||
|
description="Knowledge base top k"
|
||||||
|
)
|
||||||
|
|
||||||
|
retrieve_type: RetrieveType = Field(
|
||||||
|
default=RetrieveType.PARTICIPLE,
|
||||||
|
description="Retrieve type"
|
||||||
|
)
|
||||||
93
api/app/core/workflow/nodes/knowledge/node.py
Normal file
93
api/app/core/workflow/nodes/knowledge/node.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.models import knowledge_model, knowledgeshare_model
|
||||||
|
from app.repositories import knowledge_repository
|
||||||
|
from app.schemas.chunk_schema import RetrieveType
|
||||||
|
from app.services import knowledge_service, knowledgeshare_service
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeRetrievalNode(BaseNode):
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
|
query = self._render_template(self.typed_config.query, state)
|
||||||
|
with get_db_context() as db:
|
||||||
|
filters = [
|
||||||
|
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
|
||||||
|
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
|
||||||
|
knowledge_model.Knowledge.chunk_num > 0,
|
||||||
|
knowledge_model.Knowledge.status == 1
|
||||||
|
]
|
||||||
|
existing_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||||
|
db=db,
|
||||||
|
filters=filters
|
||||||
|
)
|
||||||
|
filters = [
|
||||||
|
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
|
||||||
|
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
||||||
|
knowledge_model.Knowledge.chunk_num > 0,
|
||||||
|
knowledge_model.Knowledge.status == 1
|
||||||
|
]
|
||||||
|
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
|
||||||
|
db=db,
|
||||||
|
filters=filters
|
||||||
|
)
|
||||||
|
if share_ids:
|
||||||
|
filters = [
|
||||||
|
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
|
||||||
|
]
|
||||||
|
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
||||||
|
db=db,
|
||||||
|
filters=filters
|
||||||
|
)
|
||||||
|
existing_ids.extend(items)
|
||||||
|
|
||||||
|
if not existing_ids:
|
||||||
|
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
||||||
|
|
||||||
|
kb_id = existing_ids[0]
|
||||||
|
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||||
|
indices = ",".join(uuid_strs)
|
||||||
|
|
||||||
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
|
||||||
|
if not db_knowledge:
|
||||||
|
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||||
|
|
||||||
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
|
||||||
|
match self.typed_config.retrieve_type:
|
||||||
|
case RetrieveType.PARTICIPLE:
|
||||||
|
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=self.typed_config.similarity_threshold)
|
||||||
|
return [chunk.model_dump() for chunk in rs]
|
||||||
|
case RetrieveType.SEMANTIC:
|
||||||
|
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=self.typed_config.vector_similarity_weight)
|
||||||
|
return [chunk.model_dump() for chunk in rs]
|
||||||
|
case _:
|
||||||
|
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=self.typed_config.vector_similarity_weight)
|
||||||
|
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=self.typed_config.similarity_threshold)
|
||||||
|
# Efficient deduplication
|
||||||
|
seen_ids = set()
|
||||||
|
unique_rs = []
|
||||||
|
for doc in rs1 + rs2:
|
||||||
|
if doc.metadata["doc_id"] not in seen_ids:
|
||||||
|
seen_ids.add(doc.metadata["doc_id"])
|
||||||
|
unique_rs.append(doc)
|
||||||
|
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
|
||||||
|
return [chunk.model_dump() for chunk in rs]
|
||||||
@@ -7,7 +7,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||||
from app.core.workflow.nodes.agent import AgentNode
|
from app.core.workflow.nodes.agent import AgentNode
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.end import EndNode
|
from app.core.workflow.nodes.end import EndNode
|
||||||
@@ -29,7 +29,7 @@ WorkflowNode = Union[
|
|||||||
AgentNode,
|
AgentNode,
|
||||||
TransformNode,
|
TransformNode,
|
||||||
AssignerNode,
|
AssignerNode,
|
||||||
# KnowledgeRetrievalNode,
|
KnowledgeRetrievalNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class NodeFactory:
|
|||||||
NodeType.AGENT: AgentNode,
|
NodeType.AGENT: AgentNode,
|
||||||
NodeType.TRANSFORM: TransformNode,
|
NodeType.TRANSFORM: TransformNode,
|
||||||
NodeType.IF_ELSE: IfElseNode,
|
NodeType.IF_ELSE: IfElseNode,
|
||||||
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||||
NodeType.ASSIGNER: AssignerNode,
|
NodeType.ASSIGNER: AssignerNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class StartNode(BaseNode):
|
|||||||
|
|
||||||
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
"""初始化 Start 节点
|
"""初始化 Start 节点
|
||||||
|
|
||||||
@@ -32,10 +32,10 @@ class StartNode(BaseNode):
|
|||||||
workflow_config: 工作流配置
|
workflow_config: 工作流配置
|
||||||
"""
|
"""
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
|
|
||||||
# 解析并验证配置
|
# 解析并验证配置
|
||||||
self.typed_config = StartNodeConfig(**self.config)
|
self.typed_config = StartNodeConfig(**self.config)
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""执行 start 节点业务逻辑
|
"""执行 start 节点业务逻辑
|
||||||
|
|
||||||
@@ -48,13 +48,13 @@ class StartNode(BaseNode):
|
|||||||
包含系统参数、会话变量和自定义变量的字典
|
包含系统参数、会话变量和自定义变量的字典
|
||||||
"""
|
"""
|
||||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||||
|
|
||||||
# 创建变量池实例(在方法内复用)
|
# 创建变量池实例(在方法内复用)
|
||||||
pool = self.get_variable_pool(state)
|
pool = self.get_variable_pool(state)
|
||||||
|
|
||||||
# 处理自定义变量(传入 pool 避免重复创建)
|
# 处理自定义变量(传入 pool 避免重复创建)
|
||||||
custom_vars = self._process_custom_variables(pool)
|
custom_vars = self._process_custom_variables(pool)
|
||||||
|
|
||||||
# 返回业务数据(包含自定义变量)
|
# 返回业务数据(包含自定义变量)
|
||||||
result = {
|
result = {
|
||||||
"message": pool.get("sys.message"),
|
"message": pool.get("sys.message"),
|
||||||
@@ -64,14 +64,14 @@ class StartNode(BaseNode):
|
|||||||
"user_id": pool.get("sys.user_id"),
|
"user_id": pool.get("sys.user_id"),
|
||||||
**custom_vars # 自定义变量作为节点输出的一部分
|
**custom_vars # 自定义变量作为节点输出的一部分
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"节点 {self.node_id} (Start) 执行完成,"
|
f"节点 {self.node_id} (Start) 执行完成,"
|
||||||
f"输出了 {len(custom_vars)} 个自定义变量"
|
f"输出了 {len(custom_vars)} 个自定义变量"
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _process_custom_variables(self, pool) -> dict[str, Any]:
|
def _process_custom_variables(self, pool) -> dict[str, Any]:
|
||||||
"""处理自定义变量
|
"""处理自定义变量
|
||||||
|
|
||||||
@@ -88,34 +88,33 @@ class StartNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
# 获取输入数据中的自定义变量
|
# 获取输入数据中的自定义变量
|
||||||
input_variables = pool.get("sys.input_variables", default={})
|
input_variables = pool.get("sys.input_variables", default={})
|
||||||
|
|
||||||
processed = {}
|
processed = {}
|
||||||
|
|
||||||
# 遍历配置的变量定义
|
# 遍历配置的变量定义
|
||||||
for var_def in self.typed_config.variables:
|
for var_def in self.typed_config.variables:
|
||||||
var_name = var_def.name
|
var_name = var_def.name
|
||||||
|
|
||||||
# 检查变量是否存在
|
# 检查变量是否存在
|
||||||
if var_name in input_variables:
|
if var_name in input_variables:
|
||||||
# 使用用户提供的值
|
# 使用用户提供的值
|
||||||
processed[var_name] = input_variables[var_name]
|
processed[var_name] = input_variables[var_name]
|
||||||
|
|
||||||
elif var_def.required:
|
elif var_def.required:
|
||||||
# 必需变量缺失
|
# 必需变量缺失
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"缺少必需的输入变量: {var_name}"
|
f"缺少必需的输入变量: {var_name}"
|
||||||
+ (f" ({var_def.description})" if var_def.description else "")
|
+ (f" ({var_def.description})" if var_def.description else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
elif var_def.default is not None:
|
elif var_def.default is not None:
|
||||||
# 使用默认值
|
# 使用默认值
|
||||||
processed[var_name] = var_def.default
|
processed[var_name] = var_def.default
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"变量 '{var_name}' 使用默认值: {var_def.default}"
|
f"变量 '{var_name}' 使用默认值: {var_def.default}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""提取输入数据(用于记录)
|
"""提取输入数据(用于记录)
|
||||||
@@ -127,7 +126,7 @@ class StartNode(BaseNode):
|
|||||||
输入数据字典
|
输入数据字典
|
||||||
"""
|
"""
|
||||||
pool = self.get_variable_pool(state)
|
pool = self.get_variable_pool(state)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"execution_id": pool.get("sys.execution_id"),
|
"execution_id": pool.get("sys.execution_id"),
|
||||||
"conversation_id": pool.get("sys.conversation_id"),
|
"conversation_id": pool.get("sys.conversation_id"),
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ def get_knowledges_paginated(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_chunded_knowledgeids(
|
def get_chunked_knowledgeids(
|
||||||
db: Session,
|
db: Session,
|
||||||
filters: list
|
filters: list
|
||||||
) -> list:
|
) -> list:
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class BaseDataSchema(BaseModel):
|
|||||||
# 保持原有必需字段为可选,以兼容不同数据源
|
# 保持原有必需字段为可选,以兼容不同数据源
|
||||||
id: Optional[str] = Field(None, description="The unique identifier for the data entry.")
|
id: Optional[str] = Field(None, description="The unique identifier for the data entry.")
|
||||||
statement: Optional[str] = Field(None, description="The statement text.")
|
statement: Optional[str] = Field(None, description="The statement text.")
|
||||||
created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.")
|
created_at: Optional[str] = Field(None, description="The creation timestamp in ISO 8601 format.")
|
||||||
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
|
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
|
||||||
description: Optional[str] = Field(None, description="The description of the data entry.")
|
description: Optional[str] = Field(None, description="The description of the data entry.")
|
||||||
|
|
||||||
@@ -46,6 +46,14 @@ class BaseDataSchema(BaseModel):
|
|||||||
relationship: Optional[Union[str, Dict[str, Any]]] = Field(None, description="The relationship object or string.")
|
relationship: Optional[Union[str, Dict[str, Any]]] = Field(None, description="The relationship object or string.")
|
||||||
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
|
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
def _set_default_created_at(cls, v):
|
||||||
|
"""Set default created_at if missing"""
|
||||||
|
if isinstance(v, dict) and v.get("created_at") is None:
|
||||||
|
from datetime import datetime
|
||||||
|
v["created_at"] = datetime.now().isoformat()
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class QualityAssessmentSchema(BaseModel):
|
class QualityAssessmentSchema(BaseModel):
|
||||||
"""Schema for memory quality assessment results."""
|
"""Schema for memory quality assessment results."""
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def get_chunded_knowledgeids(
|
|||||||
business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}")
|
business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
items = knowledge_repository.get_chunded_knowledgeids(
|
items = knowledge_repository.get_chunked_knowledgeids(
|
||||||
db=db,
|
db=db,
|
||||||
filters=filters
|
filters=filters
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -44,11 +44,11 @@ nodes:
|
|||||||
- role: user
|
- role: user
|
||||||
content: "{{ sys.message }}"
|
content: "{{ sys.message }}"
|
||||||
|
|
||||||
model_id: gpt-3.5-turbo
|
model_id: null
|
||||||
temperature: 0.7
|
temperature: 0.7
|
||||||
max_tokens: 1000
|
max_tokens: 1000
|
||||||
position:
|
position:
|
||||||
x: 300
|
x: 500
|
||||||
y: 100
|
y: 100
|
||||||
|
|
||||||
- id: end
|
- id: end
|
||||||
@@ -57,7 +57,7 @@ nodes:
|
|||||||
config:
|
config:
|
||||||
output: "{{ llm_qa.output }}"
|
output: "{{ llm_qa.output }}"
|
||||||
position:
|
position:
|
||||||
x: 500
|
x: 900
|
||||||
y: 100
|
y: 100
|
||||||
|
|
||||||
edges:
|
edges:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ version: '3.9'
|
|||||||
services:
|
services:
|
||||||
# MCP Server - standalone service
|
# MCP Server - standalone service
|
||||||
mcp-server:
|
mcp-server:
|
||||||
image: redbear-mem:latest
|
image: redbear-mem-open:latest
|
||||||
container_name: mcp-server
|
container_name: mcp-server
|
||||||
ports:
|
ports:
|
||||||
- "8081:8081" # MCP server port
|
- "8081:8081" # MCP server port
|
||||||
@@ -28,14 +28,14 @@ services:
|
|||||||
|
|
||||||
# FastAPI application - connects to MCP server
|
# FastAPI application - connects to MCP server
|
||||||
api:
|
api:
|
||||||
image: redbear-mem:latest
|
image: redbear-mem-open:latest
|
||||||
container_name: api
|
container_name: api
|
||||||
ports:
|
ports:
|
||||||
- "8002:8000"
|
- "8002:8000"
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
environment:
|
environment:
|
||||||
- MCP_SERVER_URL=http://mcp-server:8081
|
- MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name
|
||||||
- SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces
|
- SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces
|
||||||
volumes:
|
volumes:
|
||||||
- ./files:/files
|
- ./files:/files
|
||||||
@@ -51,12 +51,12 @@ services:
|
|||||||
|
|
||||||
# Celery worker - connects to MCP server
|
# Celery worker - connects to MCP server
|
||||||
worker:
|
worker:
|
||||||
image: redbear-mem:latest
|
image: redbear-mem-open:latest
|
||||||
container_name: worker
|
container_name: worker
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
environment:
|
environment:
|
||||||
- MCP_SERVER_URL=http://mcp-server:8081
|
- MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name
|
||||||
volumes:
|
volumes:
|
||||||
- ./files:/files
|
- ./files:/files
|
||||||
- /etc/localtime:/etc/localtime:ro
|
- /etc/localtime:/etc/localtime:ro
|
||||||
|
|||||||
@@ -71,6 +71,18 @@ ENABLE_SINGLE_SESSION=
|
|||||||
MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024
|
MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024
|
||||||
FILE_PATH=/files
|
FILE_PATH=/files
|
||||||
|
|
||||||
|
# RAG Setting
|
||||||
|
DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
|
||||||
|
HF_ENDPOINT=https://hf-mirror.com
|
||||||
|
MINERU_EXECUTABLE=mineru
|
||||||
|
MINERU_APISERVER=http://host.docker.internal:9987
|
||||||
|
MINERU_OUTPUT_DIR=/files
|
||||||
|
MINERU_BACKEND=pipeline
|
||||||
|
MINERU_DELETE_OUTPUT=1
|
||||||
|
TEXTLN_APISERVER=https://api.textin.com/ai/service/v1/pdf_to_markdown
|
||||||
|
TEXTLN_APP_ID=
|
||||||
|
TEXTLN_SECRET_CODE=
|
||||||
|
|
||||||
# VOLC ASR
|
# VOLC ASR
|
||||||
VOLC_APP_KEY=
|
VOLC_APP_KEY=
|
||||||
VOLC_ACCESS_KEY=
|
VOLC_ACCESS_KEY=
|
||||||
|
|||||||
Reference in New Issue
Block a user