Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
This commit is contained in:
@@ -151,6 +151,9 @@ class Settings:
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
|
||||
@@ -3,53 +3,43 @@
|
||||
"source_data": [
|
||||
{
|
||||
"statement_name": "用户是2023年春天去北京工作的。",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "62beac695b1346f4871740a45db88782"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户后来基本一直都在北京上班。",
|
||||
"statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从2023年开始就一直在北京生活。",
|
||||
"statement_id": "e612a44da4db483993c350df7c97a1a1",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "e612a44da4db483993c350df7c97a1a1"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从来没有长期离开过北京。",
|
||||
"statement_id": "b3c787a2e33c49f7981accabbbb4538a",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "b3c787a2e33c49f7981accabbbb4538a"
|
||||
},
|
||||
{
|
||||
"statement_name": "由于公司调整,用户在2024年上半年被调到上海待了差不多半年。",
|
||||
"statement_id": "64cde4230cb24a4da726e7db9e7aa616",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "64cde4230cb24a4da726e7db9e7aa616"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在入职时使用的身份信息是之前的,身份证号为11010119950308123X。",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的银行卡号是6222023847595898。",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的身份信息和银行卡信息一直没变。",
|
||||
"statement_id": "b3ca618e1e204b83bebd70e75cf2073f",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "b3ca618e1e204b83bebd70e75cf2073f"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户认为在上海的那段时间更多算是远程配合。",
|
||||
"statement_id": "150af89d2c154e6eb41ff1a91e37f962",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "150af89d2c154e6eb41ff1a91e37f962"
|
||||
}
|
||||
],
|
||||
"databasets": [
|
||||
@@ -57,24 +47,11 @@
|
||||
"entity1_name": "Person",
|
||||
"description": "表示人类个体的通用类型",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "用户",
|
||||
"entity2": {
|
||||
"entity_idx": 0,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Person",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "用户",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "3d3896797b334572a80d57590026063d"
|
||||
}
|
||||
},
|
||||
@@ -82,24 +59,11 @@
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "身份信息",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "Strong",
|
||||
"description": "用于个人身份识别的数据",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Information",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "身份信息",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "aa766a517e82490599a9b3af54cfd933"
|
||||
}
|
||||
},
|
||||
@@ -107,24 +71,11 @@
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "6222023847595898",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "Strong",
|
||||
"description": "用户的银行卡号码",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Numeric",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "6222023847595898",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "610ba361918f4e68a65ce6ad06e5c7a0"
|
||||
}
|
||||
},
|
||||
@@ -132,25 +83,13 @@
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "上海办公室",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"aliases": ["上海办"],
|
||||
"connect_strength": "Strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "位于上海的工作办公场所",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Location",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "上海办公室",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "fb702ef695c14e14af3e56786bc8815b"
|
||||
}
|
||||
},
|
||||
@@ -158,25 +97,12 @@
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "北京",
|
||||
"entity2": {
|
||||
"entity_idx": 2,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"aliases": ["京", "京城", "北平"],
|
||||
"connect_strength": "strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "中国的首都城市,用户主要工作和生活所在地",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Location",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "北京",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "81b2d1a571bb46a08a2d7a1e87efb945"
|
||||
}
|
||||
},
|
||||
@@ -184,24 +110,11 @@
|
||||
"entity1_name": "11010119950308123X",
|
||||
"description": "具体的身份证号码值",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "身份证号",
|
||||
"entity2": {
|
||||
"entity_idx": 2,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "strong",
|
||||
"description": "中华人民共和国公民的身份号码",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Identifier",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "身份证号",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "3e5f920645b2404fadb0e9ff60d1306e"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,8 +17,23 @@ import uuid
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.config import get_model_config
|
||||
from app.core.memory.utils.config.get_data import (
|
||||
extract_and_process_changes,
|
||||
get_data,
|
||||
get_data_statement,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.prompt.template_render import (
|
||||
render_evaluate_prompt,
|
||||
render_reflexion_prompt,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.response_utils import success
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
UPDATE_STATEMENT_INVALID_AT,
|
||||
neo4j_query_all,
|
||||
neo4j_query_part,
|
||||
neo4j_statement_all,
|
||||
@@ -26,6 +41,10 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConflictResultSchema,
|
||||
ReflexionResultSchema,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
@@ -38,7 +57,9 @@ if not _root_logger.handlers:
|
||||
else:
|
||||
_root_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
"""翻译响应模型"""
|
||||
data: str
|
||||
class ReflectionRange(str, Enum):
|
||||
"""反思范围枚举"""
|
||||
PARTIAL = "partial" # 从检索结果中反思
|
||||
@@ -66,6 +87,7 @@ class ReflectionConfig(BaseModel):
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
language_type: str = "zh"
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
@@ -126,6 +148,7 @@ class ReflectionEngine:
|
||||
self.update_query = update_query
|
||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
self._lazy_init_done = False
|
||||
|
||||
@@ -135,7 +158,6 @@ class ReflectionEngine:
|
||||
return
|
||||
|
||||
if self.neo4j_connector is None:
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
|
||||
if self.llm_client is None:
|
||||
@@ -147,20 +169,35 @@ class ReflectionEngine:
|
||||
self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
# from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
model_id = self.llm_client
|
||||
# model_id = self.llm_client
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(model_id)
|
||||
# self.llm_client = factory.get_llm_client(model_id)
|
||||
extra_params={
|
||||
"temperature": 0.2, # 降低温度提高响应速度和一致性
|
||||
"max_tokens": 600, # 限制最大token数
|
||||
"top_p": 0.8, # 优化采样参数
|
||||
"stream": False, # 确保非流式输出以获得最快响应
|
||||
}
|
||||
|
||||
model_config = get_model_config(self.llm_client)
|
||||
self.llm_client = OpenAIClient(RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url"),
|
||||
timeout=model_config.get("timeout", 30),
|
||||
max_retries=model_config.get("max_retries", 2),
|
||||
extra_params=extra_params
|
||||
), type_=model_config.get("type"))
|
||||
|
||||
if self.get_data_func is None:
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
self.get_data_func = get_data
|
||||
|
||||
# 导入get_data_statement函数
|
||||
if not hasattr(self, 'get_data_statement'):
|
||||
from app.core.memory.utils.config.get_data import get_data_statement
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
if self.render_evaluate_prompt_func is None:
|
||||
@@ -176,11 +213,9 @@ class ReflectionEngine:
|
||||
self.render_reflexion_prompt_func = render_reflexion_prompt
|
||||
|
||||
if self.conflict_schema is None:
|
||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||
self.conflict_schema = ConflictResultSchema
|
||||
|
||||
if self.reflexion_schema is None:
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
self.reflexion_schema = ReflexionResultSchema
|
||||
|
||||
if self.update_query is None:
|
||||
@@ -227,15 +262,12 @@ class ReflectionEngine:
|
||||
print(100 * '-')
|
||||
print(conflict_data)
|
||||
print(100 * '-')
|
||||
|
||||
# 检查是否真的有冲突
|
||||
has_conflict = conflict_data[0].get('conflict', False)
|
||||
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
|
||||
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
|
||||
# # 检查是否真的有冲突
|
||||
conflicts_found=''
|
||||
|
||||
# 记录冲突数据
|
||||
await self._log_data("conflict", conflict_data)
|
||||
|
||||
conflicts_found=''
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
|
||||
if not solved_data:
|
||||
@@ -256,7 +288,7 @@ class ReflectionEngine:
|
||||
await self._log_data("solved_data", solved_data)
|
||||
|
||||
# 4. 应用反思结果(更新记忆库)
|
||||
memories_updated = await self._apply_reflection_results(solved_data)
|
||||
memories_updated=await self._apply_reflection_results(solved_data)
|
||||
|
||||
execution_time = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
@@ -280,9 +312,60 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
async def Translate(self, text):
|
||||
# 翻译中文为英文
|
||||
translation_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}"
|
||||
}
|
||||
]
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=translation_messages,
|
||||
response_model=TranslationResponse
|
||||
)
|
||||
return response.data
|
||||
async def extract_translation(self,data):
|
||||
end_datas={}
|
||||
end_datas['source_data']=await self.Translate(data['source_data'])
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
reflexion_data=[]
|
||||
if data['memory_verifies']!=[]:
|
||||
for i in data['memory_verifies']:
|
||||
end_data={}
|
||||
end_data['has_privacy'] = i['has_privacy']
|
||||
privacy=i['privacy_types']
|
||||
privacy_types_=[]
|
||||
for pri in privacy:
|
||||
privacy_types_.append(await self.Translate(pri))
|
||||
end_data['privacy_types']=privacy_types_
|
||||
end_data['summary']=await self.Translate(i['summary'])
|
||||
memory_verifies.append(end_data)
|
||||
end_datas['memory_verifies']=memory_verifies
|
||||
|
||||
if data['quality_assessments']!=[]:
|
||||
for i in data['quality_assessments']:
|
||||
end_data = {}
|
||||
end_data['score']=i['score']
|
||||
end_data['summary'] = await self.Translate(i['summary'])
|
||||
quality_assessments.append(end_data)
|
||||
end_datas['quality_assessments'] = quality_assessments
|
||||
for i in data['reflexion_data']:
|
||||
end_data = {}
|
||||
end_data['reason'] = await self.Translate(i['reason'])
|
||||
end_data['solution'] = await self.Translate(i['solution'])
|
||||
reflexion_data.append(end_data)
|
||||
end_datas['reflexion_data'] = reflexion_data
|
||||
return end_datas
|
||||
|
||||
async def reflection_run(self):
|
||||
self._lazy_init()
|
||||
start_time = time.time()
|
||||
memory_verifies_flag = self.config.memory_verify
|
||||
quality_assessment=self.config.quality_assessment
|
||||
language_type=self.config.language_type
|
||||
|
||||
asyncio.get_event_loop().time()
|
||||
logging.info("====== 自我反思流程开始 ======")
|
||||
@@ -291,20 +374,18 @@ class ReflectionEngine:
|
||||
|
||||
source_data, databasets = await self.extract_fields_from_json()
|
||||
result_data['baseline'] = self.config.baseline
|
||||
result_data[
|
||||
'source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
|
||||
result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(databasets, source_data)
|
||||
# 遍历数据提取字段
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
for item in conflict_data:
|
||||
print(item)
|
||||
quality_assessments.append(item['quality_assessment'])
|
||||
memory_verifies.append(item['memory_verify'])
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
|
||||
# 检查是否真的有冲突
|
||||
has_conflict = conflict_data[0].get('conflict', False)
|
||||
@@ -314,8 +395,16 @@ class ReflectionEngine:
|
||||
# 记录冲突数据
|
||||
await self._log_data("conflict", conflict_data)
|
||||
|
||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||
cleaned_conflict_data = []
|
||||
for item in conflict_data:
|
||||
cleaned_item = {
|
||||
'data': item['data'],
|
||||
'conflict': item['conflict']
|
||||
}
|
||||
cleaned_conflict_data.append(cleaned_item)
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(conflict_data, source_data)
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
success=False,
|
||||
@@ -331,6 +420,14 @@ class ReflectionEngine:
|
||||
for result in item['results']:
|
||||
reflexion_data.append(result['reflexion'])
|
||||
result_data['reflexion_data'] = reflexion_data
|
||||
if memory_verifies_flag==False:
|
||||
result_data['memory_verifies']=[]
|
||||
if quality_assessment==False:
|
||||
result_data['quality_assessments']=[]
|
||||
|
||||
if language_type=='en':
|
||||
result_data=await self.extract_translation(result_data)
|
||||
print(time.time()-start_time,'----------')
|
||||
return result_data
|
||||
|
||||
|
||||
@@ -407,12 +504,13 @@ class ReflectionEngine:
|
||||
return []
|
||||
|
||||
# 使用转换后的数据
|
||||
print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
# print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
quality_assessment = self.config.quality_assessment
|
||||
language_type=self.config.language_type
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
@@ -422,7 +520,8 @@ class ReflectionEngine:
|
||||
self.config.baseline,
|
||||
memory_verify,
|
||||
quality_assessment,
|
||||
statement_databasets
|
||||
statement_databasets,
|
||||
language_type
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
@@ -485,6 +584,7 @@ class ReflectionEngine:
|
||||
memory_verify,
|
||||
statement_databasets
|
||||
)
|
||||
logging.info(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
|
||||
@@ -537,7 +637,8 @@ class ReflectionEngine:
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
"""
|
||||
success_count = await neo4j_data(solved_data)
|
||||
changes = extract_and_process_changes(solved_data)
|
||||
success_count = await neo4j_data(changes)
|
||||
return success_count
|
||||
|
||||
async def _log_data(self, label: str, data: Any) -> None:
|
||||
@@ -644,5 +745,8 @@ class ReflectionEngine:
|
||||
execution_time=time_result.execution_time + fact_result.execution_time
|
||||
)
|
||||
else:
|
||||
|
||||
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,20 @@ import uuid
|
||||
import logging
|
||||
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from openai import BaseModel
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from pydantic import model_validator, Field
|
||||
|
||||
from app.schemas.memory_storage_schema import SingleReflexionResultSchema
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
from app.repositories.neo4j.neo4j_update import map_field_names
|
||||
# 添加项目根目录到 Python 路径
|
||||
sys.path.append(str(Path(__file__).parent))
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def _load_(data: List[Any]) -> List[Dict]:
|
||||
@@ -59,6 +73,14 @@ async def get_data(result):
|
||||
"""
|
||||
从数据库中获取数据
|
||||
"""
|
||||
EXCLUDE_FIELDS = {
|
||||
"user_id",
|
||||
"group_id",
|
||||
"entity_type",
|
||||
"connect_strength",
|
||||
"relationship_type",
|
||||
"apply_id"
|
||||
}
|
||||
neo4j_databasets=[]
|
||||
for item in result:
|
||||
filtered_item = {}
|
||||
@@ -73,14 +95,17 @@ async def get_data(result):
|
||||
rel_filtered['statement_id'] = value.get('statement_id')
|
||||
rel_filtered['expired_at'] = value.get('expired_at')
|
||||
rel_filtered['created_at'] = value.get('created_at')
|
||||
filtered_item[key] = rel_filtered
|
||||
filtered_item[key] = value
|
||||
elif key == 'entity2' and value is not None:
|
||||
# 过滤entity2的name_embedding字段
|
||||
entity2_filtered = {}
|
||||
if hasattr(value, 'items'):
|
||||
for e_key, e_value in value.items():
|
||||
if 'name_embedding' not in e_key.lower():
|
||||
entity2_filtered[e_key] = e_value
|
||||
if e_key in EXCLUDE_FIELDS:
|
||||
continue
|
||||
if 'name_embedding' in e_key.lower():
|
||||
continue
|
||||
entity2_filtered[e_key] = e_value
|
||||
filtered_item[key] = entity2_filtered
|
||||
else:
|
||||
filtered_item[key] = value
|
||||
@@ -94,8 +119,57 @@ async def get_data_statement( result):
|
||||
neo4j_databasets.append(i)
|
||||
return neo4j_databasets
|
||||
|
||||
class ReflexionResultSchema(BaseModel):
|
||||
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
||||
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _normalize_resolved(cls, v):
|
||||
if isinstance(v, dict):
|
||||
conflict = v.get("conflict")
|
||||
if isinstance(conflict, dict) and conflict.get("conflict") is False:
|
||||
v["resolved"] = None
|
||||
else:
|
||||
resolved = v.get("resolved")
|
||||
if isinstance(resolved, dict):
|
||||
orig = resolved.get("original_memory_id")
|
||||
mem = resolved.get("resolved_memory")
|
||||
if orig is None and (mem is None or mem == {}):
|
||||
v["resolved"] = None
|
||||
return v
|
||||
def extract_and_process_changes(DATA):
|
||||
"""提取并处理 change 字段"""
|
||||
all_changes = []
|
||||
for i, item in enumerate(DATA):
|
||||
try:
|
||||
result = ReflexionResultSchema(**item)
|
||||
for j, res in enumerate(result.results):
|
||||
if res.resolved and res.resolved.change:
|
||||
for k, change in enumerate(res.resolved.change):
|
||||
change_data = {}
|
||||
for field_item in change.field:
|
||||
for key, value in field_item.items():
|
||||
change_data[key] = value
|
||||
if isinstance(value, list):
|
||||
print(f" - {key}: {value[0]} -> {value[1]}")
|
||||
else:
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
all_changes.append({
|
||||
'data': change_data
|
||||
})
|
||||
|
||||
# 测试字段映射
|
||||
try:
|
||||
mapped = map_field_names(change_data)
|
||||
print(f" 映射结果: {mapped}")
|
||||
except Exception as e:
|
||||
print(f" 映射失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理结果 {i + 1} 失败: {e}")
|
||||
|
||||
return all_changes
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
@@ -1,222 +1,88 @@
|
||||
你将收到一组用户历史记忆原始数据(来源于 Neo4j),以及相关配置参数:
|
||||
原本的输入句子:{{statement_databasets}}
|
||||
需要检测冲突对象:{{ evaluate_data }}
|
||||
冲突判定类型:{{ baseline }}(取值为 TIME / FACT / HYBRID)
|
||||
记忆审核开关:{{ memory_verify }}(取值为 true / false)
|
||||
记忆质量评估开关开关:{{ quality_assessment }}(取值为 true / false)
|
||||
|
||||
你的任务是:
|
||||
对用户历史记忆数据进行冲突检测和记忆审核,并输出严格结构化的 JSON 分析结果
|
||||
数据的结构:
|
||||
statement_databasets里面statement_name是输入的句子,statement_id是连接evaluate_data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容,
|
||||
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估)
|
||||
## 冲突定义
|
||||
# 记忆数据分析任务
|
||||
|
||||
## 输入数据
|
||||
- **原始句子**: {{statement_databasets}}
|
||||
- **检测对象**: {{ evaluate_data }}
|
||||
- **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID)
|
||||
- **隐私审核**: {{ memory_verify }} (true/false)
|
||||
- **质量评估**: {{ quality_assessment }} (true/false)
|
||||
- **语言类型**:{{language_type}}(zh/en)
|
||||
## 任务目标
|
||||
对用户记忆数据进行冲突检测、隐私审核和质量评估,输出结构化JSON结果。
|
||||
**数据关系**: statement_databasets中的statement_id对应evaluate_data中的记录,代表句子拆分后的实体关系。
|
||||
## 1. 冲突检测
|
||||
### 时间冲突
|
||||
时间冲突是指同一用户的相关事件在时间维度上存在逻辑矛盾:
|
||||
|
||||
1. **同一活动的时间冲突**:
|
||||
- 同一用户的同一活动在不同时间点被记录(如"周五打球"和"周六打球")
|
||||
- 同一用户在同一时间段内被记录进行不同的互斥活动
|
||||
|
||||
2. **时间逻辑错误**:
|
||||
- expired_at 早于 created_at
|
||||
- 同一事实的 created_at 时间差异超过合理误差范围(>5分钟)
|
||||
|
||||
3. **日期属性冲突**:
|
||||
- 同一人的生日记录为不同日期(如"2月10号"和"2月16号")
|
||||
4.存在明确先后约束 A -> B,但 t(A) > t(B)
|
||||
-例:入学时间晚于毕业时间。
|
||||
-处理:标记异常、降权、触发逻辑反思或人工审查。
|
||||
5.时间属性冲突
|
||||
-单值日期属性出现多值(生日、入职日期)
|
||||
-注意:本质属于事实冲突的日期特例,归入事实冲突仲裁框架。
|
||||
6.互斥重叠冲突
|
||||
-例:同一主体的两个事件区间重叠且互斥(如同一时间出现在两地)
|
||||
-处理:证据仲裁、保留多版本(active + candidate)。
|
||||
|
||||
|
||||
|
||||
- **同一活动时间矛盾**: 同一用户同一活动的不同时间记录
|
||||
- **时间逻辑错误**: expired_at < created_at,created_at时间差>5分钟
|
||||
- **日期属性冲突**: 同一人的生日等单值属性出现多值
|
||||
- **先后约束违反**: 存在A→B约束但t(A)>t(B)(如入学>毕业)
|
||||
- **互斥重叠**: 同一时间出现在不同地点等互斥事件
|
||||
### 事实冲突
|
||||
事实冲突是指同一实体的属性或关系存在相互矛盾的陈述:
|
||||
|
||||
1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
|
||||
2. **关系矛盾**:同一实体在相同语境下的不同关系描述
|
||||
3. **身份冲突**:同一实体被赋予不同的类型或角色
|
||||
|
||||
### 混合冲突检测
|
||||
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:
|
||||
检测任何逻辑上不一致或相互矛盾的记录
|
||||
## 记忆审核定义
|
||||
|
||||
### 隐私信息检测(隐私冲突)
|
||||
当memory_verify为true时,需要额外检测包含个人隐私信息的记录:
|
||||
|
||||
1. **身份证信息**:包含身份证号码、身份证相关描述
|
||||
2. **手机号码**:包含手机号、电话号码等联系方式
|
||||
3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息
|
||||
4. **银行信息**:包含银行卡号、账户信息、支付信息
|
||||
5. **税务信息**:包含税号、纳税信息、发票信息
|
||||
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
|
||||
7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息
|
||||
|
||||
### 隐私检测原则
|
||||
- 检测description、entity1_name、entity2_name等字段中的隐私信息
|
||||
- 识别数字模式(如手机号11位数字、身份证18位等)
|
||||
- 识别关键词(如"身份证"、"银行卡"、"密码"等)
|
||||
- 检测敏感实体类型和关系
|
||||
|
||||
## 冲突检测原则
|
||||
|
||||
**全面检测**:不区分冲突类型,检测所有可能的冲突
|
||||
**完整输出**:如果发现任何冲突或隐私信息,必须将所有相关记录都放入data字段
|
||||
**实体关联**:重点检查涉及相同实体(entity1_name, entity2_name)的记录
|
||||
**语义分析**:分析description字段的语义相似性和冲突性
|
||||
**时间逻辑**:检查时间字段的逻辑一致性
|
||||
**隐私检测**:当memory_verify为true时,检测所有包含隐私信息的记录
|
||||
|
||||
## 不符合冲突检测
|
||||
-称呼
|
||||
## 重要检测示例
|
||||
|
||||
### 冲突检测示例
|
||||
- 用户与不同时间点的关系(周五 vs 周六,2月10号 vs 2月16号)
|
||||
- 同一实体的重复定义但描述不同
|
||||
- 同一关系的不同表述但含义冲突
|
||||
- 任何逻辑上不可能同时为真的记录
|
||||
|
||||
### 隐私信息检测示例
|
||||
- 包含手机号的记录:"用户的手机号是13812345678"
|
||||
- 包含身份证的记录:"身份证号码为110101199001011234"
|
||||
- 包含银行卡的记录:"银行卡号6222021234567890"
|
||||
- 包含社交账号的记录:"微信号是user123456"
|
||||
- 包含敏感信息的实体名称或描述
|
||||
|
||||
## 输出要求
|
||||
|
||||
**关键原则**:
|
||||
1. 当存在冲突或检测到隐私信息时,conflict才为true,data字段才包含相关记录
|
||||
2. 如果发现冲突,必须将所有相关的冲突记录都放入data数组中
|
||||
3. 如果memory_verify为true且检测到隐私信息,必须将包含隐私信息的记录也放入data数组中
|
||||
4. 既没有冲突也没有隐私信息时,conflict为false,data为空数组
|
||||
5. 如果quality_assessment为true,独立分析数据质量并输出评估结果;如果为false,quality_assessment字段输出null
|
||||
6. 冲突检测、隐私审核和质量评估三个功能完全独立,互不影响
|
||||
7. 不输出conflict_memory字段
|
||||
|
||||
**处理逻辑**:
|
||||
- 首先进行冲突检测,将冲突记录加入data数组
|
||||
- 如果memory_verify为true,再进行隐私信息检测,将包含隐私信息的记录也加入data数组
|
||||
- 如果quality_assessment为true,独立进行质量评估,分析所有输入数据的质量并输出评估结果
|
||||
- 最终data数组包含所有冲突记录和隐私信息记录(去重)
|
||||
- quality_assessment字段独立输出,不影响冲突检测和隐私审核结果
|
||||
- memory_verify字段独立输出隐私检测结果,包含检测到的隐私信息类型和概述
|
||||
|
||||
返回数据格式以json方式输出:
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求{"statement":识别出的文本内容}
|
||||
1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号("")或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
|
||||
## 记忆质量评估定义
|
||||
|
||||
### 质量评估标准
|
||||
当quality_assessment为true时,需要对记忆数据进行质量评估:
|
||||
|
||||
1. **数据完整性**:
|
||||
- 检查必要字段是否完整(entity1_name、entity2_name、description等)
|
||||
- 检查关系描述是否清晰明确
|
||||
- 检查时间字段的有效性
|
||||
|
||||
2. **重复字段检测**:
|
||||
- 识别相同或高度相似的记录
|
||||
- 检测冗余的实体关系
|
||||
- 分析描述内容的重复度
|
||||
|
||||
3. **无意义字段检测**:
|
||||
- 识别空值、无效值或占位符内容
|
||||
- 检测过于简单或无信息量的描述
|
||||
- 识别格式错误或不规范的数据
|
||||
|
||||
4. **上下文依赖性**:
|
||||
- 评估记录是否需要额外上下文才能理解
|
||||
- 检查实体名称的明确性
|
||||
- 分析关系描述的自包含性
|
||||
|
||||
### 质量评估输出
|
||||
- **质量百分比**:基于上述标准计算的整体质量分数(0-100)
|
||||
- **质量概述**:简要描述数据质量状况,包括主要问题和优点
|
||||
|
||||
输出是仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
- **属性互斥**: 同一实体的相反属性(喜欢↔不喜欢)
|
||||
- **关系矛盾**: 同一实体在相同语境下的不同关系描述
|
||||
- **身份冲突**: 同一实体被赋予不同类型或角色
|
||||
### 混合冲突
|
||||
检测所有逻辑不一致或相互矛盾的记录。
|
||||
**检测原则**:
|
||||
- 重点检查相同实体的记录
|
||||
- 分析description字段语义冲突
|
||||
- 验证时间字段逻辑一致性
|
||||
## 2. 隐私审核 (memory_verify=true时)
|
||||
### 隐私信息类型
|
||||
- **身份信息**: 身份证号码、身份证相关描述
|
||||
- **联系方式**: 手机号、电话号码
|
||||
- **社交账号**: 微信号、QQ号、邮箱地址
|
||||
- **金融信息**: 银行卡号、账户信息、支付信息
|
||||
- **税务信息**: 税号、纳税信息、发票信息
|
||||
- **贷款信息**: 贷款记录、信贷信息
|
||||
- **安全信息**: 密码、PIN码、验证码
|
||||
### 检测方法
|
||||
- 检测description、entity1_name、entity2_name、name等字段
|
||||
- 识别数字模式(手机号11位、身份证18位等)
|
||||
- 识别关键词("身份证"、"银行卡"、"密码"等)
|
||||
## 3. 质量评估 (quality_assessment=true时)
|
||||
### 评估标准
|
||||
- **数据完整性**: 必要字段完整性、关系描述清晰度、时间字段有效性
|
||||
- **重复检测**: 相同或高度相似记录、冗余实体关系、描述重复度
|
||||
- **无意义检测**: 空值/无效值、过于简单的描述、格式错误
|
||||
- **上下文依赖**: 记录自包含性、实体名称明确性
|
||||
### 输出内容
|
||||
- **质量分数**: 0-100的整体质量百分比
|
||||
- **质量概述**: 简要描述数据质量状况和主要问题
|
||||
## 输出规则
|
||||
### 核心原则
|
||||
1. **conflict=true**: 存在冲突或隐私信息时,将所有相关记录放入data数组
|
||||
2. **conflict=false**: 无冲突且无隐私信息时,data为空数组
|
||||
3. **独立功能**: 冲突检测、隐私审核、质量评估三者完全独立
|
||||
4. **条件输出**:
|
||||
- quality_assessment=true时输出评估对象,否则为null
|
||||
- memory_verify=true时输出隐私检测对象,否则为null
|
||||
5. **不输出conflict_memory字段**
|
||||
### 处理流程
|
||||
1. 冲突检测 → 将冲突记录加入data
|
||||
2. 隐私审核(如启用) → 将隐私记录加入data
|
||||
3. 质量评估(如启用) → 独立输出评估结果
|
||||
4. 去重data数组中的记录
|
||||
**输出结构**:
|
||||
```json
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"entity1_name": "实体1名称",
|
||||
"description": "描述信息",
|
||||
"statement_id": "陈述ID",
|
||||
"created_at": "创建时间戳",
|
||||
"expired_at": "过期时间戳",
|
||||
"relationship_type": "关系类型",
|
||||
"relationship": "关系对象",
|
||||
"entity2_name": "实体2名称",
|
||||
"entity2": "实体2对象"
|
||||
}
|
||||
],
|
||||
"conflict": true或false,
|
||||
"data": [记录数组],
|
||||
"conflict": true/false,
|
||||
"quality_assessment": {
|
||||
"score": 质量百分比数字,
|
||||
"summary": "质量概述文本"
|
||||
"score": 数字,
|
||||
"summary": "文本"
|
||||
} 或 null,
|
||||
"memory_verify": {
|
||||
"has_privacy": true或false,
|
||||
"privacy_types": ["检测到的隐私信息类型列表"],
|
||||
"summary": "隐私检测结果概述"
|
||||
"has_privacy": true/false,
|
||||
"privacy_types": ["类型数组"],
|
||||
"summary": "概述文本"
|
||||
} 或 null
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本。
|
||||
- 使用标准双引号,必要时对内部引号进行转义。
|
||||
- 字段名与结构必须与给定模式一致。
|
||||
- data数组中包含冲突记录和隐私信息记录,如果都没有则为空数组。
|
||||
- quality_assessment字段:当quality_assessment参数为true时输出评估对象,为false时输出null。
|
||||
- memory_verify字段:当memory_verify参数为true时输出隐私检测结果对象,为false时输出null。
|
||||
|
||||
### memory_verify字段说明
|
||||
当memory_verify为true时,需要输出隐私检测结果:
|
||||
- **has_privacy**: 布尔值,表示是否检测到隐私信息
|
||||
- **privacy_types**: 字符串数组,包含检测到的隐私信息类型(如["手机号码", "身份证信息"])
|
||||
- **summary**: 字符串,简要描述隐私检测结果
|
||||
|
||||
当memory_verify为false时,memory_verify字段输出null。
|
||||
|
||||
### memory_verify字段示例
|
||||
|
||||
**示例1:检测到隐私信息**
|
||||
```json
|
||||
"memory_verify": {
|
||||
"has_privacy": true,
|
||||
"privacy_types": ["手机号码", "身份证信息"],
|
||||
"summary": "检测到2条记录包含隐私信息:1个手机号码,1个身份证号码"
|
||||
}
|
||||
```
|
||||
|
||||
**示例2:未检测到隐私信息**
|
||||
```json
|
||||
"memory_verify": {
|
||||
"has_privacy": false,
|
||||
"privacy_types": [],
|
||||
"summary": "未检测到隐私信息"
|
||||
}
|
||||
```
|
||||
|
||||
**示例3:memory_verify为false时**
|
||||
```json
|
||||
"memory_verify": null
|
||||
```
|
||||
|
||||
模式参考:
|
||||
{{ json_schema }}
|
||||
**字段说明**:
|
||||
- **data**: 包含冲突记录和隐私信息记录,无则为空数组
|
||||
- **quality_assessment**:
|
||||
quality_assessment=true时输出评估对象,否则为null(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
- **memory_verify**: memory_verify=true时输出隐私检测对象,否则为null
|
||||
(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
模式参考:{{ json_schema }}
|
||||
@@ -1,200 +1,155 @@
|
||||
你将收到一组用户历史记忆原始数据(来源于 Neo4j)
|
||||
你将收到一条冲突判定对象:{{ data }}。
|
||||
需要检测冲突对象:{{ statement_databasets }}
|
||||
以及需要识别的冲突对象为:{{ baseline }}
|
||||
记忆审核开关:{{ memory_verify }}(取值为 true / false)
|
||||
# 记忆冲突解决任务
|
||||
|
||||
角色:
|
||||
- 你是数据领域中解决数据冲突的专家
|
||||
## 输入数据
|
||||
- **冲突数据**: {{ data }}
|
||||
- **原始句子**: {{ statement_databasets }}
|
||||
- **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID)
|
||||
- **隐私审核**: {{ memory_verify }} (true/false)
|
||||
- **语言类型**:{{language_type}}(zh/en)
|
||||
|
||||
任务:分析冲突产生原因,按冲突类型分组处理,为每种冲突类型生成独立的解决方案。
|
||||
## 任务目标
|
||||
作为数据冲突解决专家,分析冲突原因,按类型分组处理,为每种冲突生成独立解决方案。
|
||||
|
||||
数据的结构:
|
||||
statement_databasets里面statement_name是输入的句子,statement_id是连接data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容,
|
||||
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估),data里面的statement_created_at是用户输入的时间
|
||||
**数据关系**: statement_databasets中的statement_id对应data中的记录,statement_created_at为用户输入时间。
|
||||
|
||||
**处理模式**:
|
||||
- 当memory_verify为false时:仅处理数据冲突
|
||||
- 当memory_verify为true时:处理数据冲突 + 隐私信息脱敏
|
||||
**处理模式**:
|
||||
- memory_verify=false: 仅处理数据冲突
|
||||
- memory_verify=true: 处理数据冲突 + 隐私脱敏
|
||||
|
||||
## 分组处理原则
|
||||
## 1. 冲突类型定义
|
||||
|
||||
**冲突类型识别与分组**:
|
||||
1. **日期冲突**:
|
||||
1.1.涉及用户生日的不同日期记录(如2月10号 vs 2月16号),
|
||||
1.2.涉及同一活动的不同时间记录(如周五打球 vs 周六打球)
|
||||
3. **事实属性冲突**:
|
||||
3.1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
|
||||
3.2. **关系矛盾**:同一实体在相同语境下的不同关系描述
|
||||
3.3. **身份冲突**:同一实体被赋予不同的类型或角色
|
||||
4. **其他冲突类型/混合冲突(时间+事实)**:根据具体数据识别
|
||||
### 时间冲突 (TIME)
|
||||
时间维度冲突:两个事件时间重叠,或同一事情在不同时间场景下的变化。
|
||||
|
||||
**分组输出要求**:
|
||||
- 每种冲突类型生成一个独立的reflexion_result对象
|
||||
- 同一类型的多个冲突记录归并到一个结果中
|
||||
- 不同类型的冲突分别处理,各自生成独立结果
|
||||
### 事实冲突 (FACT)
|
||||
同一事实对象的陈述内容相互矛盾,真假不能共存的情况。
|
||||
|
||||
## 冲突类型定义
|
||||
### 混合冲突 (HYBRID)
|
||||
检测所有类型冲突,包括时间和事实冲突的任何逻辑不一致记录。
|
||||
|
||||
### 时间冲突(TIME)
|
||||
时间维度冲突是指两个事件发生时间重叠,或者用户同一件事情和场景等情况下,时间出现了变化。
|
||||
## 2. 分组处理原则
|
||||
|
||||
### 事实冲突(FACT)
|
||||
事实冲突是指同一事实对象(同一个人、同一个时间、同一个状态)但陈述内容相互矛盾,主要为真假不能共存的情况。
|
||||
### 混合冲突(HYBRID)
|
||||
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:检测任何逻辑上不一致或相互矛盾的记录
|
||||
{% if memory_verify %}
|
||||
## 隐私信息处理(memory_verify为true时启用)
|
||||
### 冲突类型识别
|
||||
- **日期冲突**: 用户生日不同日期(2月10号 vs 2月16号)、同一活动不同时间(周五 vs 周六打球)
|
||||
- **事实属性冲突**:
|
||||
- 属性互斥(喜欢↔不喜欢)
|
||||
- 关系矛盾(同一实体不同关系描述)
|
||||
- 身份冲突(同一实体不同类型/角色)
|
||||
- **其他/混合冲突**: 根据具体数据识别
|
||||
|
||||
### 隐私信息识别
|
||||
需要识别并处理以下类型的隐私信息:
|
||||
### 分组输出要求
|
||||
- 每种冲突类型生成独立的reflexion_result对象
|
||||
- 同类型多个冲突归并到一个结果
|
||||
- 不同类型分别处理,各自生成独立结果
|
||||
## 3. 隐私信息处理 (memory_verify=true时)
|
||||
|
||||
1. **身份证信息**:包含身份证号码、身份证相关描述
|
||||
2. **手机号码**:包含手机号、电话号码等联系方式
|
||||
3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息
|
||||
4. **银行信息**:包含银行卡号、账户信息、支付信息
|
||||
5. **税务信息**:包含税号、纳税信息、发票信息
|
||||
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
|
||||
7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息
|
||||
### 隐私信息类型
|
||||
- **身份信息**: 身份证号码、身份证相关描述
|
||||
- **联系方式**: 手机号、电话号码
|
||||
- **社交账号**: 微信号、QQ号、邮箱地址
|
||||
- **金融信息**: 银行卡号、账户信息、支付信息
|
||||
- **税务信息**: 税号、纳税信息、发票信息
|
||||
- **贷款信息**: 贷款记录、信贷信息
|
||||
- **安全信息**: 密码、PIN码、验证码
|
||||
|
||||
### 隐私数据脱敏规则
|
||||
对于检测到的隐私信息,按以下规则进行脱敏处理:
|
||||
### 脱敏规则
|
||||
**数字类**: 保留前三位和后四位,中间用*代替
|
||||
- 手机号: 13812345678 → 138****5678
|
||||
- 身份证: 110101199001011234 → 110***********1234
|
||||
- 银行卡: 6222021234567890 → 622***********7890
|
||||
|
||||
**数字类隐私信息脱敏**:
|
||||
- 保留前三位和后四位,中间用*代替
|
||||
- 示例:手机号13812345678 → 138****5678
|
||||
- 示例:身份证110101199001011234 → 110***********1234
|
||||
- 示例:银行卡6222021234567890 → 622***********7890
|
||||
**文本类**: 保留前三后四位字符,中间用*代替
|
||||
- 微信号: user123456 → use****3456
|
||||
- 邮箱: zhang.san@example.com → zha****@example.com
|
||||
|
||||
**文本类隐私信息脱敏**:
|
||||
- 社交账号:保留前三后四位字符,中间用*代替
|
||||
- 示例:微信号user123456 → use****3456
|
||||
- 示例:邮箱zhang.san@example.com → zha****@example.com
|
||||
**脱敏字段**: name、entity1_name、entity2_name、description、relationship
|
||||
|
||||
**脱敏处理字段**:
|
||||
- name字段:如包含隐私信息需脱敏
|
||||
- entity1_name字段:如包含隐私信息需脱敏
|
||||
- entity2_name字段:如包含隐私信息需脱敏
|
||||
- description字段:如包含隐私信息需脱敏
|
||||
{% endif %}
|
||||
## 4. 处理流程
|
||||
|
||||
## 工作步骤
|
||||
### 步骤1: 类型匹配验证
|
||||
**匹配规则**:
|
||||
- baseline="TIME": 只处理时间相关冲突(涉及时间表达式、日期、时间点)
|
||||
- baseline="FACT": 只处理事实相关冲突(属性矛盾、关系冲突、描述不一致)
|
||||
- baseline="HYBRID": 处理所有类型冲突
|
||||
|
||||
### 第一步:分析冲突类型匹配
|
||||
首先判断输入的冲突数据是否符合baseline要求的类型:
|
||||
**类型识别**:
|
||||
- 时间冲突: entity2的entity_type包含"TimeExpression"/"TemporalExpression",或entity2_name包含时间词汇
|
||||
- 事实冲突: 相同实体的不同属性描述、互斥关系陈述
|
||||
|
||||
**类型匹配规则**:
|
||||
- 如果baseline是"TIME":只处理时间相关的冲突(涉及时间表达式、日期、时间点的冲突)
|
||||
- 如果baseline是"FACT":只处理事实相关的冲突(属性矛盾、关系冲突、描述不一致)
|
||||
- 如果baseline是"HYBRID":处理所有类型的冲突,也可以当作混合冲突类型处理
|
||||
**重要**: 类型不匹配时必须输出空结果(resolved为null)
|
||||
|
||||
**类型识别**:
|
||||
- 时间冲突标识:entity2的entity_type包含"TimeExpression"、"TemporalExpression",或entity2_name包含时间词汇(周一到周日、月份日期等)
|
||||
- 事实冲突标识:相同实体的不同属性描述、互斥的关系陈述
|
||||
### 步骤2: 冲突数据分组
|
||||
**分组策略**:
|
||||
- 时间冲突组: 涉及用户时间的记录
|
||||
- 活动时间冲突组: 同一活动不同时间的记录
|
||||
- 事实冲突组: 同一实体不同属性的记录
|
||||
- 其他冲突组: 其他类型冲突记录
|
||||
|
||||
**重要**:如果输入的冲突类型与baseline不匹配,必须输出空结果(resolved为null)
|
||||
**筛选条件**: 只处理与baseline匹配的冲突类型
|
||||
|
||||
### 第二步:筛选并分组冲突数据
|
||||
按冲突类型对数据进行分组:
|
||||
### 步骤3: 冲突解决策略
|
||||
**重要**: 数据被判定为正确时不可修改
|
||||
|
||||
**分组策略**:
|
||||
1. **时间冲突组**:筛选涉及用户时间的所有记录
|
||||
2. **活动时间冲突组**:筛选涉及同一活动不同时间的记录
|
||||
3. **事实冲突组**:筛选涉及同一实体不同属性的记录
|
||||
4. **其他冲突组**:其他类型的冲突记录
|
||||
**智能解决**:
|
||||
1. 分析冲突数据,结合statement_databasets原文判定正确性
|
||||
2. 判断正确答案是否存在于data中
|
||||
3. 根据情况选择处理方式{% if memory_verify %}
|
||||
4. 隐私脱敏处理在冲突解决后进行{% endif %}
|
||||
|
||||
**筛选条件**:
|
||||
- 只处理与baseline匹配的冲突类型
|
||||
- 相同entity1_name但entity2_name不同的记录
|
||||
- 相同关系但描述矛盾的记录
|
||||
- 时间逻辑不一致的记录
|
||||
### 处理规则
|
||||
|
||||
### 第三步:冲突解决策略
|
||||
** 不可以解决的冲突情况
|
||||
1. 数据被判定为正确的情况下,不可以进行修改
|
||||
**仅当冲突类型与baseline匹配时**,对筛选出的冲突数据进行处理:
|
||||
** baseline是TIME
|
||||
-保留正确记录不变修改错误记录的expired_at为当前时间(2025-12-16T12:00:00),以及name需要修改成正确的
|
||||
** baseline不是TIME
|
||||
- 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段,
|
||||
如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field
|
||||
|
||||
**智能解决策略**:
|
||||
1. **分析冲突数据**:识别哪些记录是正确的,哪些是错误的,需要结合statement_databasets的输入原文来判定
|
||||
2. **判断正确答案是否存在**:
|
||||
- 如果正确答案已存在于data中:只需将错误记录的expired_at设为当前日期(2025-12-16T12:00:00)
|
||||
- 如果正确答案已存在于data中:错误记录的expired_at已经设为日期,则不需要对正确的数据进行修改
|
||||
- 如果正确答案不存在于data中:需要修改现有记录的内容以包含正确信息
|
||||
**核心原则**:
|
||||
- 只输出需要修改的记录
|
||||
- 优先保留策略: 时间冲突保留最可信created_at时间,事实冲突选择最新且可信度最高记录
|
||||
- 精确记录变更: change字段包含记录ID、字段名称、新值和旧值{% if memory_verify %}
|
||||
- 隐私保护优先: 所有输出记录必须完成隐私脱敏
|
||||
- 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %}
|
||||
- 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空
|
||||
- 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true)等原数据字段以及涉及需要修改的字段以及内容
|
||||
|
||||
{% if memory_verify %}
|
||||
**隐私处理集成**:
|
||||
- 在处理冲突的同时,需要对涉及的记录进行隐私脱敏
|
||||
- 脱敏处理应该在冲突解决之后进行,确保最终输出的记录都已脱敏
|
||||
- 在change字段中记录隐私脱敏的变更
|
||||
{% endif %}
|
||||
|
||||
**具体处理规则**:
|
||||
|
||||
**情况1:正确答案存在于data中**
|
||||
- 保留正确的记录不变
|
||||
- 基于时间关系的冲突:
|
||||
需要只修改错误记录的expired_at为当前时间(2025-12-16T12:00:00)
|
||||
- 基于事实的关系冲突
|
||||
- resolved.resolved_memory只包含被设为失效的错误记录
|
||||
- change字段只记录expired_at的变更:`[{"expired_at": "2025-12-16T12:00:00"}]`(注意:如果已存在时间,则不需要对其修改,也不需要变更 时间)
|
||||
|
||||
**情况2:正确答案不存在于data中**
|
||||
- 选择最合适的记录进行修改
|
||||
- 更新该记录的相关字段:
|
||||
- description字段:添加或修改描述信息{% if memory_verify %}(如包含隐私信息,需脱敏处理){% endif %}
|
||||
- name字段:修改名称字段{% if memory_verify %}(如需要,包含隐私信息时需脱敏){% endif %}
|
||||
- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %}
|
||||
- change字段记录所有被修改的字段{% if memory_verify %},包括脱敏变更{% endif %},例如:`[{"description": "新描述"{% if memory_verify %}, "entity2_name": "138****5678"{% endif %}}]`
|
||||
|
||||
**重要原则**:
|
||||
- **只输出需要修改的记录**:resolved.resolved_memory只包含实际需要修改的数据
|
||||
- **优先保留策略**:时间冲突保留最可信的created_at时间的记录,事实冲突选择最新且可信度最高的记录
|
||||
- **精确记录变更**:change字段必须包含记录ID、字段名称、新值和旧值
|
||||
{% if memory_verify %}- **隐私保护优先**:所有输出的记录必须完成隐私脱敏处理
|
||||
- **脱敏变更记录**:隐私脱敏的变更也必须在change字段中详细记录{% endif %}
|
||||
- **不可修改数据**:数据被判定为正确时,不可以进行修改,如果没有数据可输出空
|
||||
|
||||
**变更记录格式**:
|
||||
**变更记录格式**:
|
||||
```json
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
{"id":修改字段对应的ID}
|
||||
{"statement_id":需要修改的对象对应的statement_id}
|
||||
{"字段名1": ["修改前的值1","修改后的值1"]},
|
||||
{"字段名2": ["修改前的值2","修改后的值2"]}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
**类型不匹配处理**:
|
||||
- 如果冲突类型与baseline不匹配,resolved必须设为null
|
||||
- reflexion.reason说明类型不匹配的原因
|
||||
**类型不匹配处理**:
|
||||
- 冲突类型与baseline不匹配时,resolved设为null
|
||||
- reflexion.reason说明类型不匹配原因
|
||||
- reflexion.solution说明无需处理
|
||||
|
||||
### 第四步:输出解决方案
|
||||
## 5. JSON输出格式
|
||||
|
||||
## 输出要求
|
||||
**嵌套字段映射**(系统会自动处理):
|
||||
**格式要求**:
|
||||
- 输出有效JSON对象,通过json.loads()解析
|
||||
- 使用标准ASCII双引号(")
|
||||
- 内部引号用反斜杠转义(\")
|
||||
- 字符串值不包含换行符
|
||||
- 不输出```json```等代码块标记
|
||||
|
||||
**嵌套字段映射**(系统自动处理):
|
||||
- `entity2.name` → 自动映射为 `name`
|
||||
- `entity1.name` → 自动映射为 `name`
|
||||
- `relationship` → 自动映射为 `statement`
|
||||
- `entity1.description` → 自动映射为 `description`
|
||||
- `entity2.description` → 自动映射为 `description`
|
||||
|
||||
返回数据格式以json方式输出:
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 响应必须是与此确切模式匹配的有效JSON对象
|
||||
- 不要在JSON之前或之后包含任何文本
|
||||
|
||||
JSON格式要求:
|
||||
1. JSON结构仅使用标准ASCII双引号(")
|
||||
2. 如果提取的语句文本包含引号,请使用反斜杠(\")正确转义
|
||||
3. 确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4. JSON字符串值中不包括换行符
|
||||
5. 不允许输出```json```相关符号
|
||||
|
||||
仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
|
||||
**输出格式:按冲突类型分组的列表**
|
||||
**输出结构**: 按冲突类型分组的列表
|
||||
```json
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
@@ -208,93 +163,24 @@ JSON格式要求:
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "被设为失效的记忆id",
|
||||
"resolved_memory": {
|
||||
"entity1_name": "实体1名称",
|
||||
"entity2_name": "实体2名称",
|
||||
"description": "描述信息",
|
||||
"statement_id": "陈述ID",
|
||||
"created_at": "创建时间",
|
||||
"expired_at": "过期时间",
|
||||
"relationship_type": "关系类型",
|
||||
"relationship": {},
|
||||
"entity2": {...}
|
||||
},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
]
|
||||
}
|
||||
]
|
||||
"resolved_memory": {记录对象},
|
||||
"change": [变更记录数组]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**示例:多种冲突类型的输出**
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"conflict": {
|
||||
"data": [生日冲突相关的记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "检测到生日冲突:用户同时关联2月10号和2月16号两个不同日期",
|
||||
"solution": "保留最新记录(2月16号),将旧记录(2月10号)设为失效"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "df066210883545a08e727ccd8ad4ec77",
|
||||
"resolved_memory": {...},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"expired_at": "2025-12-16T12:00:00"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
},
|
||||
{
|
||||
"conflict": {
|
||||
"data": [篮球时间冲突相关的记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "检测到活动时间冲突:用户打篮球时间存在周五和周六的冲突",
|
||||
"solution": "保留最可信的时间记录,将冲突记录设为失效"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "另一个记录ID",
|
||||
"resolved_memory": {...},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"description": "使用系统的个人,指代说话者本人,篮球时间为周六"},
|
||||
{"entity2_name": "周六"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本
|
||||
- 使用标准双引号,必要时对内部引号进行转义
|
||||
- 字段名与结构必须与给定模式一致
|
||||
- **输出必须是results数组格式**,每个冲突类型作为一个独立的对象
|
||||
- **按冲突类型分组**:相同类型的冲突记录归并到一个result对象中
|
||||
- **每个result对象的conflict.data**只包含该冲突类型相关的记录
|
||||
- **resolved.resolved_memory 只包含需要修改的记录**,不需要修改的记录不要输出
|
||||
- **resolved.change 必须包含详细的变更信息**:field数组包含所有被修改的字段及其新值
|
||||
- 如果某个冲突类型经分析无需修改任何数据,该类型的resolved 必须为 null
|
||||
- 如果与baseline不匹配的冲突类型,不要在results中包含该类型
|
||||
|
||||
模式参考:
|
||||
{{ json_schema }}
|
||||
**输出要求**:
|
||||
- 只输出JSON,不添加解释文本
|
||||
- 使用标准双引号,必要时转义
|
||||
- 字段名与结构必须与模式一致
|
||||
- **results数组格式**: 每个冲突类型作为独立对象
|
||||
- **按冲突类型分组**: 相同类型冲突归并到一个result对象
|
||||
- **conflict.data**: 只包含该冲突类型相关记录
|
||||
- **resolved.resolved_memory**: 只包含需要修改的记录
|
||||
- **resolved.change**: 包含详细变更信息
|
||||
- 无需修改的冲突类型resolved为null
|
||||
- 与baseline不匹配的冲突类型不包含在results中
|
||||
模式参考: {{ json_schema }}
|
||||
@@ -9,7 +9,8 @@ prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any],
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,statement_databasets: List[str] = []) -> str:
|
||||
memory_verify: bool = False,quality_assessment:bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
@@ -30,12 +31,13 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
|
||||
baseline=baseline,
|
||||
memory_verify=memory_verify,
|
||||
quality_assessment=quality_assessment,
|
||||
statement_databasets=statement_databasets
|
||||
statement_databasets=statement_databasets,
|
||||
language_type=language_type
|
||||
)
|
||||
return rendered_prompt
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = []) -> str:
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
@@ -51,6 +53,6 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any],
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets)
|
||||
statement_databasets=statement_databasets,language_type=language_type)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -69,7 +69,17 @@ class WorkflowExecutor:
|
||||
初始化的工作流状态
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
conversation_vars = input_data.get("conversation_vars") or {}
|
||||
|
||||
# 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value)
|
||||
config_variables_list = self.workflow_config.get("variables") or []
|
||||
conversation_vars = {}
|
||||
for var_def in config_variables_list:
|
||||
if isinstance(var_def, dict):
|
||||
var_name = var_def.get("name")
|
||||
var_default = var_def.get("default")
|
||||
if var_name:
|
||||
conversation_vars[var_name] = var_default
|
||||
|
||||
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||
|
||||
# 构建分层的变量结构
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
"""
|
||||
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
@@ -23,5 +25,7 @@ __all__ = [
|
||||
"StartNode",
|
||||
"EndNode",
|
||||
"NodeFactory",
|
||||
"WorkflowNode"
|
||||
"WorkflowNode",
|
||||
# "KnowledgeRetrievalNode",
|
||||
"AssignerNode",
|
||||
]
|
||||
|
||||
4
api/app/core/workflow/nodes/assigner/__init__.py
Normal file
4
api/app/core/workflow/nodes/assigner/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.assigner.node import AssignerNode
|
||||
|
||||
__all__ = ["AssignerNode", "AssignerNodeConfig"]
|
||||
21
api/app/core/workflow/nodes/assigner/config.py
Normal file
21
api/app/core/workflow/nodes/assigner/config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
|
||||
|
||||
class AssignerNodeConfig(BaseNodeConfig):
|
||||
variable_selector: str | list[str] = Field(
|
||||
...,
|
||||
description="Variables to be assigned",
|
||||
)
|
||||
|
||||
operation: AssignmentOperator = Field(
|
||||
...,
|
||||
description="Operator to assign",
|
||||
)
|
||||
|
||||
value: str | list[str] = Field(
|
||||
...,
|
||||
description="Values to assign",
|
||||
)
|
||||
80
api/app/core/workflow/nodes/assigner/node.py
Normal file
80
api/app/core/workflow/nodes/assigner/node.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssignerNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
Execute the assignment operation defined by this node.
|
||||
|
||||
Args:
|
||||
state: The current workflow state, including conversation variables,
|
||||
node outputs, and system variables.
|
||||
|
||||
Returns:
|
||||
None or the result of the assignment operation.
|
||||
"""
|
||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||
pool = VariablePool(state)
|
||||
|
||||
# Get the target variable selector (e.g., "conv.test")
|
||||
variable_selector = self.typed_config.variable_selector
|
||||
if isinstance(variable_selector, str):
|
||||
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
|
||||
variable_selector = variable_selector.split('.')
|
||||
|
||||
# Only conversation variables ('conv') are allowed
|
||||
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
|
||||
raise ValueError("Only conversation variables can be assigned.")
|
||||
|
||||
# Get the value or expression to assign
|
||||
value = self.typed_config.value
|
||||
if isinstance(value, list):
|
||||
value = '.'.join(value)
|
||||
value = ExpressionEvaluator.evaluate(
|
||||
expression=value,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
)
|
||||
|
||||
# Select the appropriate assignment operator instance based on the target variable type
|
||||
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
|
||||
pool, variable_selector, value
|
||||
)
|
||||
|
||||
# Execute the configured assignment operation
|
||||
match self.typed_config.operation:
|
||||
case AssignmentOperator.ASSIGN:
|
||||
operator.assign()
|
||||
case AssignmentOperator.CLEAR:
|
||||
operator.clear()
|
||||
case AssignmentOperator.ADD:
|
||||
operator.add()
|
||||
case AssignmentOperator.SUBTRACT:
|
||||
operator.subtract()
|
||||
case AssignmentOperator.MULTIPLY:
|
||||
operator.multiply()
|
||||
case AssignmentOperator.DIVIDE:
|
||||
operator.divide()
|
||||
case AssignmentOperator.APPEND:
|
||||
operator.append()
|
||||
case AssignmentOperator.REMOVE_FIRST:
|
||||
operator.remove_first()
|
||||
case AssignmentOperator.REMOVE_LAST:
|
||||
operator.remove_last()
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {self.typed_config.operation}")
|
||||
@@ -26,7 +26,12 @@ class WorkflowState(TypedDict):
|
||||
messages: Annotated[list[AnyMessage], add]
|
||||
|
||||
# 输入变量(从配置的 variables 传入)
|
||||
variables: dict[str, Any]
|
||||
# 使用深度合并函数,支持嵌套字典的更新(如 conv.xxx)
|
||||
variables: Annotated[dict[str, Any], lambda x, y: {
|
||||
**x,
|
||||
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
|
||||
for k, v in y.items()}
|
||||
}]
|
||||
|
||||
# 节点输出(存储每个节点的执行结果,用于变量引用)
|
||||
# 使用自定义合并函数,将新的节点输出合并到现有字典中
|
||||
@@ -544,9 +549,15 @@ class BaseNode(ABC):
|
||||
# 使用变量池获取变量
|
||||
pool = VariablePool(state)
|
||||
|
||||
# 构建完整的 variables 结构
|
||||
variables = {
|
||||
"sys": pool.get_all_system_vars(),
|
||||
"conv": pool.get_all_conversation_vars()
|
||||
}
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars()
|
||||
)
|
||||
@@ -575,9 +586,15 @@ class BaseNode(ABC):
|
||||
# 使用变量池获取变量
|
||||
pool = VariablePool(state)
|
||||
|
||||
# 构建完整的 variables 结构(包含 sys 和 conv)
|
||||
variables = {
|
||||
"sys": pool.get_all_system_vars(),
|
||||
"conv": pool.get_all_conversation_vars()
|
||||
}
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars()
|
||||
)
|
||||
|
||||
@@ -14,6 +14,8 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -28,4 +30,6 @@ __all__ = [
|
||||
"AgentNodeConfig",
|
||||
"TransformNodeConfig",
|
||||
"IfElseNodeConfig",
|
||||
# "KnowledgeRetrievalNodeConfig",
|
||||
"AssignerNodeConfig",
|
||||
]
|
||||
|
||||
@@ -33,7 +33,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
@@ -45,17 +45,17 @@ class EndNode(BaseNode):
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _extract_referenced_nodes(self, template: str) -> list[str]:
|
||||
"""从模板中提取引用的节点 ID
|
||||
|
||||
|
||||
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
|
||||
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
|
||||
Returns:
|
||||
引用的节点 ID 列表
|
||||
"""
|
||||
@@ -63,44 +63,44 @@ class EndNode(BaseNode):
|
||||
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
||||
matches = re.findall(pattern, template)
|
||||
return list(set(matches)) # 去重
|
||||
|
||||
|
||||
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
|
||||
"""解析模板,分离静态文本和动态引用
|
||||
|
||||
|
||||
例如:'你好 {{llm.output}}, 这是后缀'
|
||||
返回:[
|
||||
{"type": "static", "content": "你好 "},
|
||||
{"type": "dynamic", "node_id": "llm", "field": "output"},
|
||||
{"type": "static", "content": ", 这是后缀"}
|
||||
]
|
||||
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
state: 工作流状态
|
||||
|
||||
|
||||
Returns:
|
||||
模板部分列表
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
parts = []
|
||||
last_end = 0
|
||||
|
||||
|
||||
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
|
||||
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
|
||||
|
||||
|
||||
for match in re.finditer(pattern, template):
|
||||
start, end = match.span()
|
||||
|
||||
|
||||
# 添加前面的静态文本
|
||||
if start > last_end:
|
||||
static_text = template[last_end:start]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
|
||||
# 解析动态引用
|
||||
ref = match.group(1).strip()
|
||||
|
||||
|
||||
# 检查是否是节点引用(如 llm.output 或 llm_qa.output)
|
||||
if '.' in ref:
|
||||
node_id, field = ref.split('.', 1)
|
||||
@@ -115,60 +115,62 @@ class EndNode(BaseNode):
|
||||
# 直接渲染这部分
|
||||
rendered = self._render_template(f"{{{{{ref}}}}}", state)
|
||||
parts.append({"type": "static", "content": rendered})
|
||||
|
||||
|
||||
last_end = end
|
||||
|
||||
|
||||
# 添加最后的静态文本
|
||||
if last_end < len(template):
|
||||
static_text = template[last_end:]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行 end 节点业务逻辑
|
||||
|
||||
|
||||
智能输出策略:
|
||||
1. 检测模板中是否引用了直接上游节点
|
||||
2. 如果引用了,只输出该引用**之后**的部分(后缀)
|
||||
3. 前缀和引用内容已经在上游节点流式输出时发送了
|
||||
|
||||
|
||||
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
||||
- 直接上游节点是 llm_qa
|
||||
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
|
||||
- LLM 内容在 LLM 节点流式输出
|
||||
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
|
||||
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
|
||||
Yields:
|
||||
完成标记
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
||||
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
|
||||
if not output_template:
|
||||
output = "工作流已完成"
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
|
||||
# 找到直接上游节点
|
||||
direct_upstream_nodes = []
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
if edge.get("target") == self.node_id:
|
||||
source_node_id = edge.get("source")
|
||||
direct_upstream_nodes.append(source_node_id)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
|
||||
|
||||
|
||||
# 解析模板部分
|
||||
parts = self._parse_template_parts(output_template, state)
|
||||
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
logger.info(f"[模板解析] part[{i}]: {part}")
|
||||
|
||||
# 找到第一个引用直接上游节点的动态引用
|
||||
upstream_ref_index = None
|
||||
for i, part in enumerate(parts):
|
||||
@@ -176,12 +178,12 @@ class EndNode(BaseNode):
|
||||
upstream_ref_index = i
|
||||
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
|
||||
break
|
||||
|
||||
|
||||
if upstream_ref_index is None:
|
||||
# 没有引用直接上游节点,输出完整模板内容
|
||||
output = self._render_template(output_template, state)
|
||||
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'")
|
||||
|
||||
|
||||
# 通过 writer 发送完整内容(作为一个 message chunk)
|
||||
from langgraph.config import get_stream_writer
|
||||
writer = get_stream_writer()
|
||||
@@ -194,57 +196,56 @@ class EndNode(BaseNode):
|
||||
"is_suffix": False
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
||||
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
|
||||
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
|
||||
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
|
||||
|
||||
|
||||
# 收集后缀部分
|
||||
suffix_parts = []
|
||||
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}")
|
||||
for i in range(upstream_ref_index + 1, len(parts)):
|
||||
part = parts[i]
|
||||
|
||||
logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
|
||||
if part["type"] == "static":
|
||||
# 静态文本
|
||||
logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'")
|
||||
suffix_parts.append(part["content"])
|
||||
|
||||
|
||||
elif part["type"] == "dynamic":
|
||||
# 其他动态引用(如果有多个引用)
|
||||
# Other dynamic references (if there are multiple references)
|
||||
node_id = part["node_id"]
|
||||
field = part["field"]
|
||||
|
||||
# 从 streaming_buffer 或 node_outputs 读取
|
||||
streaming_buffer = state.get("streaming_buffer", {})
|
||||
if node_id in streaming_buffer:
|
||||
buffer_data = streaming_buffer[node_id]
|
||||
content = buffer_data.get("full_content", "")
|
||||
else:
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
runtime_vars = state.get("runtime_vars", {})
|
||||
|
||||
# Use VariablePool to get variable value
|
||||
pool = self.get_variable_pool(state)
|
||||
try:
|
||||
# Try to get variable value with default empty string
|
||||
content = pool.get([node_id, field], default="")
|
||||
logger.info(f"[后缀调试] 获取变量 {node_id}.{field} 成功: '{content}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}")
|
||||
content = ""
|
||||
if node_id in node_outputs:
|
||||
node_output = node_outputs[node_id]
|
||||
if isinstance(node_output, dict):
|
||||
content = str(node_output.get(field, ""))
|
||||
elif node_id in runtime_vars:
|
||||
runtime_var = runtime_vars[node_id]
|
||||
if isinstance(runtime_var, dict):
|
||||
content = str(runtime_var.get(field, ""))
|
||||
|
||||
suffix_parts.append(content)
|
||||
|
||||
# Convert to string if not None
|
||||
suffix_parts.append(str(content) if content is not None else "")
|
||||
|
||||
# 拼接后缀
|
||||
suffix = "".join(suffix_parts)
|
||||
|
||||
|
||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||
full_output = self._render_template(output_template, state)
|
||||
|
||||
|
||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
||||
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
||||
logger.info(f"[后缀调试] 后缀是否为空: {not suffix}")
|
||||
|
||||
if suffix:
|
||||
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")
|
||||
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})")
|
||||
# 一次性输出后缀(作为单个 chunk)
|
||||
# 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理
|
||||
# 而是通过 writer 直接发送
|
||||
@@ -260,13 +261,13 @@ class EndNode(BaseNode):
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||
else:
|
||||
logger.info(f"节点 {self.node_id} 没有后缀需要输出")
|
||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}")
|
||||
|
||||
# 统计信息
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
|
||||
|
||||
|
||||
# yield 完成标记(包含完整输出)
|
||||
yield {"__final__": True, "result": full_output}
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from app.core.workflow.nodes.operators import (
|
||||
StringOperator,
|
||||
NumberOperator,
|
||||
AssignmentOperatorType,
|
||||
BooleanOperator,
|
||||
ArrayOperator,
|
||||
ObjectOperator
|
||||
)
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
@@ -14,6 +23,7 @@ class NodeType(StrEnum):
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
AGENT = "agent"
|
||||
ASSIGNER = "assigner"
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
@@ -34,3 +44,32 @@ class ComparisonOperator(StrEnum):
|
||||
class LogicOperator(StrEnum):
|
||||
AND = "and"
|
||||
OR = "or"
|
||||
|
||||
|
||||
class AssignmentOperator(StrEnum):
|
||||
ASSIGN = "assign"
|
||||
CLEAR = "clear"
|
||||
|
||||
ADD = "add" # +=
|
||||
SUBTRACT = "subtract" # -=
|
||||
MULTIPLY = "multiply" # *=
|
||||
DIVIDE = "divide" # /=
|
||||
|
||||
APPEND = "append"
|
||||
REMOVE_LAST = "remove_last"
|
||||
REMOVE_FIRST = "remove_first"
|
||||
|
||||
@classmethod
|
||||
def get_operator(cls, obj) -> AssignmentOperatorType:
|
||||
if isinstance(obj, str):
|
||||
return StringOperator
|
||||
elif isinstance(obj, bool):
|
||||
return BooleanOperator
|
||||
elif isinstance(obj, (int, float)):
|
||||
return NumberOperator
|
||||
elif isinstance(obj, list):
|
||||
return ArrayOperator
|
||||
elif isinstance(obj, dict):
|
||||
return ObjectOperator
|
||||
|
||||
raise TypeError(f"Unsupported variable type ({type(obj)})")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail
|
||||
|
||||
@@ -11,6 +11,7 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -136,7 +137,7 @@ class LLMNode(BaseNode):
|
||||
base_url=api_base,
|
||||
extra_params=extra_params
|
||||
),
|
||||
type=model_type
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
@@ -15,6 +16,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,6 +28,8 @@ WorkflowNode = Union[
|
||||
IfElseNode,
|
||||
AgentNode,
|
||||
TransformNode,
|
||||
AssignerNode,
|
||||
# KnowledgeRetrievalNode,
|
||||
]
|
||||
|
||||
|
||||
@@ -42,7 +46,9 @@ class NodeFactory:
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.AGENT: AgentNode,
|
||||
NodeType.TRANSFORM: TransformNode,
|
||||
NodeType.IF_ELSE: IfElseNode
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.ASSIGNER: AssignerNode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -82,10 +88,6 @@ class NodeFactory:
|
||||
"""
|
||||
node_type = node_config.get("type")
|
||||
|
||||
# 跳过条件节点(由 LangGraph 处理)
|
||||
if node_type == "condition":
|
||||
return None
|
||||
|
||||
# 获取节点类
|
||||
node_class = cls._node_types.get(node_type)
|
||||
if not node_class:
|
||||
|
||||
146
api/app/core/workflow/nodes/operators.py
Normal file
146
api/app/core/workflow/nodes/operators.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from abc import ABC
|
||||
from typing import Union, Type
|
||||
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
|
||||
class OperatorBase(ABC):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
self.pool = pool
|
||||
self.left_selector = left_selector
|
||||
self.right = right
|
||||
|
||||
self.type_limit: type[str, int, dict, list] = None
|
||||
|
||||
def check(self, no_right=False):
|
||||
left = self.pool.get(self.left_selector)
|
||||
if not isinstance(left, self.type_limit):
|
||||
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
|
||||
|
||||
if not no_right and not isinstance(self.right, self.type_limit):
|
||||
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type")
|
||||
|
||||
|
||||
class StringOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = str
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, '')
|
||||
|
||||
|
||||
class NumberOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = (float, int)
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, 0)
|
||||
|
||||
def add(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin + self.right)
|
||||
|
||||
def subtract(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin - self.right)
|
||||
|
||||
def multiply(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin * self.right)
|
||||
|
||||
def divide(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin / self.right)
|
||||
|
||||
|
||||
class BooleanOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = bool
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, False)
|
||||
|
||||
|
||||
class ArrayOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = list
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, list())
|
||||
|
||||
def append(self) -> None:
|
||||
self.check(no_right=True)
|
||||
# TODO:require type limit in list
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin.append(self.right)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
|
||||
def extend(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin.extend(self.right)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
|
||||
def remove_last(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin.pop()
|
||||
self.pool.set(self.left_selector, origin)
|
||||
|
||||
def remove_first(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin.pop(0)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
|
||||
|
||||
class ObjectOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = object
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, dict())
|
||||
|
||||
|
||||
AssignmentOperatorInstance = Union[
|
||||
StringOperator,
|
||||
NumberOperator,
|
||||
BooleanOperator,
|
||||
ArrayOperator,
|
||||
ObjectOperator
|
||||
]
|
||||
AssignmentOperatorType = Type[AssignmentOperatorInstance]
|
||||
@@ -66,19 +66,26 @@ class TemplateRenderer:
|
||||
'分析结果: 正面情绪'
|
||||
"""
|
||||
# 构建命名空间上下文
|
||||
# variables 的结构:{"sys": {...}, "conv": {...}}
|
||||
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
|
||||
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
|
||||
|
||||
context = {
|
||||
"var": variables, # 用户变量:{{var.user_input}}
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": system_vars or {}, # 系统变量:{{sys.execution_id}}
|
||||
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
|
||||
}
|
||||
|
||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||
# 将所有节点输出添加到顶层上下文
|
||||
context.update(node_outputs)
|
||||
if node_outputs:
|
||||
context.update(node_outputs)
|
||||
|
||||
# 为了向后兼容,也支持直接访问用户变量
|
||||
context.update(variables)
|
||||
context["nodes"] = node_outputs # 旧语法兼容
|
||||
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||
if conv_vars:
|
||||
context.update(conv_vars)
|
||||
|
||||
context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
|
||||
try:
|
||||
tmpl = self.env.from_string(template)
|
||||
|
||||
@@ -10,7 +10,10 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -82,7 +85,7 @@ class VariablePool:
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
"""
|
||||
|
||||
def __init__(self, state: dict[str, Any]):
|
||||
def __init__(self, state: "WorkflowState"):
|
||||
"""初始化变量池
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user