Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
This commit is contained in:
85
api/app/core/memory/models/emotion_models.py
Normal file
85
api/app/core/memory/models/emotion_models.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Emotion extraction models for LLM structured output.
|
||||
|
||||
This module contains Pydantic models for emotion extraction from statements,
|
||||
designed to be used with LLM structured output capabilities.
|
||||
|
||||
Classes:
|
||||
EmotionExtraction: Model for emotion extraction results from statements
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmotionExtraction(BaseModel):
    """Emotion extraction result model for LLM structured output.

    Represents the structured emotion information extracted from a single
    statement by an LLM: emotion type, intensity, keywords, subject
    classification and an optional target.

    Attributes:
        emotion_type: Type of emotion (joy/sadness/anger/fear/surprise/neutral)
        emotion_intensity: Intensity of emotion (0.0-1.0)
        emotion_keywords: List of emotion keywords from the statement (max 3)
        emotion_subject: Subject of emotion (self/other/object)
        emotion_target: Optional target of emotion (person or object name)
    """

    emotion_type: str = Field(
        ...,
        description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
    )
    emotion_intensity: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Emotion intensity from 0.0 to 1.0"
    )
    emotion_keywords: List[str] = Field(
        default_factory=list,
        description="Emotion keywords extracted from the statement (max 3)"
    )
    emotion_subject: str = Field(
        ...,
        description="Emotion subject: self/other/object"
    )
    emotion_target: Optional[str] = Field(
        None,
        description="Emotion target: person or object name"
    )

    @field_validator('emotion_type')
    @classmethod
    def validate_emotion_type(cls, v):
        """Validate emotion type is one of the valid values.

        Raises:
            ValueError: If ``v`` is not one of the supported emotion types.
        """
        valid_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
        if v not in valid_types:
            raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
        return v

    @field_validator('emotion_subject')
    @classmethod
    def validate_emotion_subject(cls, v):
        """Validate emotion subject is one of the valid values.

        Raises:
            ValueError: If ``v`` is not one of the supported subjects.
        """
        valid_subjects = ['self', 'other', 'object']
        if v not in valid_subjects:
            raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
        return v

    # mode='before' is required here: in the default "after" mode the value
    # has already been validated as List[str], so a non-list payload (for
    # example an explicit null from the LLM) would be rejected before this
    # validator could fall back to an empty list.
    @field_validator('emotion_keywords', mode='before')
    @classmethod
    def validate_emotion_keywords(cls, v):
        """Normalise emotion keywords and limit them to at most 3 items.

        Anything that is not a list or tuple (including ``None``) becomes an
        empty list; longer sequences are truncated to their first three
        entries. Item types are still checked by the field's ``List[str]``
        annotation afterwards.
        """
        if not isinstance(v, (list, tuple)):
            return []
        # Limit to max 3 keywords
        return list(v)[:3]

    @field_validator('emotion_intensity')
    @classmethod
    def validate_emotion_intensity(cls, v):
        """Validate emotion intensity is within the closed range [0.0, 1.0].

        Raises:
            ValueError: If ``v`` falls outside the range.
        """
        if not (0.0 <= v <= 1.0):
            raise ValueError(f"emotion_intensity must be between 0.0 and 1.0, got {v}")
        return v
|
||||
@@ -215,24 +215,58 @@ class StatementNode(Node):
|
||||
Attributes:
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
stmt_type: Type of the statement (from ontology)
|
||||
temporal_info: Temporal information extracted from the statement
|
||||
statement: The actual statement text content
|
||||
connect_strength: Classification of connection strength ('Strong' or 'Weak')
|
||||
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
|
||||
emotion_target: Optional emotion target (person or object name)
|
||||
emotion_subject: Optional emotion subject (self/other/object)
|
||||
emotion_type: Optional emotion type (joy/sadness/anger/fear/surprise/neutral)
|
||||
emotion_keywords: Optional list of emotion keywords (max 3)
|
||||
temporal_info: Temporal information extracted from the statement
|
||||
valid_at: Optional start date of temporal validity
|
||||
invalid_at: Optional end date of temporal validity
|
||||
statement_embedding: Optional embedding vector for the statement
|
||||
chunk_embedding: Optional embedding vector for the parent chunk
|
||||
connect_strength: Classification of connection strength ('Strong' or 'Weak')
|
||||
config_id: Configuration ID used to process this statement
|
||||
"""
|
||||
# Core fields (ordered as requested)
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk")
|
||||
stmt_type: str = Field(..., description="Type of the statement")
|
||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||
statement: str = Field(..., description="The statement text content")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
|
||||
# Emotion fields (ordered as requested, emotion_intensity first for display)
|
||||
emotion_intensity: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Emotion intensity: 0.0-1.0 (displayed on node)"
|
||||
)
|
||||
emotion_target: Optional[str] = Field(
|
||||
None,
|
||||
description="Emotion target: person or object name"
|
||||
)
|
||||
emotion_subject: Optional[str] = Field(
|
||||
None,
|
||||
description="Emotion subject: self/other/object"
|
||||
)
|
||||
emotion_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
|
||||
)
|
||||
emotion_keywords: Optional[List[str]] = Field(
|
||||
default_factory=list,
|
||||
description="Emotion keywords list, max 3 items"
|
||||
)
|
||||
|
||||
# Temporal fields
|
||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
# Embedding and other fields
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@@ -240,6 +274,39 @@ class StatementNode(Node):
|
||||
def validate_datetime(cls, v):
|
||||
"""使用通用的历史日期解析函数"""
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
@field_validator('emotion_type', mode='before')
@classmethod
def validate_emotion_type(cls, v):
    """Accept ``None`` or one of the supported emotion type labels.

    Raises:
        ValueError: If ``v`` is set but is not a supported emotion type.
    """
    allowed = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
    # None means "no emotion extracted" and is passed through untouched.
    if v is not None and v not in allowed:
        raise ValueError(f"emotion_type must be one of {allowed}, got {v}")
    return v
|
||||
|
||||
@field_validator('emotion_subject', mode='before')
@classmethod
def validate_emotion_subject(cls, v):
    """Accept ``None`` or one of the supported emotion subject labels.

    Raises:
        ValueError: If ``v`` is set but is not a supported subject.
    """
    allowed = ['self', 'other', 'object']
    # None means "no emotion extracted" and is passed through untouched.
    if v is not None and v not in allowed:
        raise ValueError(f"emotion_subject must be one of {allowed}, got {v}")
    return v
|
||||
|
||||
@field_validator('emotion_keywords', mode='before')
@classmethod
def validate_emotion_keywords(cls, v):
    """Coerce the raw keywords value to a list of at most 3 items.

    ``None`` and any non-list value become an empty list; a list is cut
    down to its first three entries.
    """
    # isinstance() is False for None, so a single check covers both the
    # "missing" and the "wrong type" cases.
    return v[:3] if isinstance(v, list) else []
|
||||
|
||||
|
||||
class ChunkNode(Node):
|
||||
|
||||
@@ -64,6 +64,11 @@ class Statement(BaseModel):
|
||||
connect_strength: Optional connection strength ('Strong' or 'Weak')
|
||||
temporal_validity: Optional temporal validity range
|
||||
triplet_extraction_info: Optional triplet extraction results
|
||||
emotion_type: Optional emotion type (joy/sadness/anger/fear/surprise/neutral)
|
||||
emotion_intensity: Optional emotion intensity (0.0-1.0)
|
||||
emotion_keywords: Optional list of emotion keywords
|
||||
emotion_subject: Optional emotion subject (self/other/object)
|
||||
emotion_target: Optional emotion target (person or object name)
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||
@@ -80,6 +85,12 @@ class Statement(BaseModel):
|
||||
triplet_extraction_info: Optional[TripletExtractionResponse] = Field(
|
||||
None, description="The triplet extraction information of the statement."
|
||||
)
|
||||
# Emotion fields
|
||||
emotion_type: Optional[str] = Field(None, description="Emotion type: joy/sadness/anger/fear/surprise/neutral")
|
||||
emotion_intensity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Emotion intensity: 0.0-1.0")
|
||||
emotion_keywords: Optional[List[str]] = Field(default_factory=list, description="Emotion keywords, max 3")
|
||||
emotion_subject: Optional[str] = Field(None, description="Emotion subject: self/other/object")
|
||||
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
|
||||
|
||||
|
||||
class ConversationContext(BaseModel):
|
||||
|
||||
@@ -480,7 +480,6 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
|
||||
- records: textual logs including per-round/per-block summaries and per-pair decisions
|
||||
"""
|
||||
import asyncio
|
||||
import random
|
||||
# 初始化全局日志和全局ID映射(存储所有轮次的结果)
|
||||
records: List[str] = []
|
||||
|
||||
@@ -36,7 +36,6 @@ from app.core.memory.models.graph_models import (
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||
dedup_layers_and_merge_and_return,
|
||||
@@ -182,11 +181,12 @@ class ExtractionOrchestrator:
|
||||
all_statements_list.extend(chunk.statements)
|
||||
total_statements = len(all_statements_list)
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成")
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
||||
(
|
||||
triplet_maps,
|
||||
temporal_maps,
|
||||
emotion_maps,
|
||||
statement_embedding_maps,
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
@@ -209,78 +209,13 @@ class ExtractionOrchestrator:
|
||||
logger.info("步骤 3/6: 生成实体嵌入")
|
||||
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
||||
|
||||
# 进度回调:按三个阶段分别输出知识抽取结果
|
||||
if self.progress_callback:
|
||||
# 第一阶段:陈述句提取结果
|
||||
for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句
|
||||
stmt_result = {
|
||||
"extraction_type": "statement",
|
||||
"statement_index": i + 1,
|
||||
"statement": stmt.statement,
|
||||
"statement_id": stmt.id
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result)
|
||||
|
||||
# 第二阶段:三元组提取结果
|
||||
for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组
|
||||
triplet_result = {
|
||||
"extraction_type": "triplet",
|
||||
"triplet_index": i + 1,
|
||||
"subject": triplet.subject_name,
|
||||
"predicate": triplet.predicate,
|
||||
"object": triplet.object_name
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result)
|
||||
|
||||
# 第三阶段:时间提取结果
|
||||
if total_temporal > 0:
|
||||
# 收集时间信息
|
||||
temporal_results = []
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
temporal_results.append({
|
||||
"statement_id": statement.id,
|
||||
"statement": statement.statement,
|
||||
"valid_at": statement.temporal_validity.valid_at,
|
||||
"invalid_at": statement.temporal_validity.invalid_at
|
||||
})
|
||||
|
||||
# 输出时间提取结果
|
||||
for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果
|
||||
time_result = {
|
||||
"extraction_type": "temporal",
|
||||
"temporal_index": i + 1,
|
||||
"statement": temporal_result["statement"],
|
||||
"valid_at": temporal_result["valid_at"],
|
||||
"invalid_at": temporal_result["invalid_at"]
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||
else:
|
||||
# 如果没有时间信息,也发送一个时间提取完成的消息
|
||||
time_result = {
|
||||
"extraction_type": "temporal",
|
||||
"temporal_index": 0,
|
||||
"message": "未发现时间信息"
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||
|
||||
# 进度回调:知识抽取完成,传递知识抽取的统计信息
|
||||
extraction_stats = {
|
||||
"statements_count": total_statements,
|
||||
"entities_count": total_entities,
|
||||
"triplets_count": total_triplets,
|
||||
"temporal_ranges_count": total_temporal,
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
|
||||
|
||||
# 步骤 4: 将提取的数据赋值到语句
|
||||
logger.info("步骤 4/6: 数据赋值")
|
||||
dialog_data_list = await self._assign_extracted_data(
|
||||
dialog_data_list,
|
||||
temporal_maps,
|
||||
triplet_maps,
|
||||
emotion_maps,
|
||||
statement_embedding_maps,
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
@@ -288,6 +223,9 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 步骤 5: 创建节点和边
|
||||
logger.info("步骤 5/6: 创建节点和边")
|
||||
|
||||
# 注意:creating_nodes_edges 消息已在知识抽取完成后立即发送
|
||||
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -307,6 +245,8 @@ class ExtractionOrchestrator:
|
||||
else:
|
||||
logger.info("步骤 6/6: 两阶段去重和消歧")
|
||||
|
||||
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
||||
|
||||
result = await self._run_dedup_and_write_summary(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -331,7 +271,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[DialogData]:
|
||||
"""
|
||||
从对话中提取陈述句(优化版:全局分块级并行)
|
||||
从对话中提取陈述句(流式输出版本:边提取边发送进度)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -339,7 +279,7 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
更新后的对话数据列表(包含提取的陈述句)
|
||||
"""
|
||||
logger.info("开始陈述句提取(全局分块级并行)")
|
||||
logger.info("开始陈述句提取(全局分块级并行 + 流式输出)")
|
||||
|
||||
# 收集所有分块及其元数据
|
||||
all_chunks = []
|
||||
@@ -352,17 +292,44 @@ class ExtractionOrchestrator:
|
||||
chunk_metadata.append((d_idx, c_idx))
|
||||
|
||||
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
|
||||
|
||||
# 用于跟踪已完成的分块数量
|
||||
completed_chunks = 0
|
||||
total_chunks = len(all_chunks)
|
||||
|
||||
# 全局并行处理所有分块
|
||||
async def extract_for_chunk(chunk_data):
|
||||
async def extract_for_chunk(chunk_data, chunk_index):
|
||||
nonlocal completed_chunks
|
||||
chunk, group_id, dialogue_content = chunk_data
|
||||
try:
|
||||
return await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
||||
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
||||
|
||||
# 流式输出:每提取完一个分块的陈述句,立即发送进度
|
||||
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
|
||||
completed_chunks += 1
|
||||
if self.progress_callback and statements and self.is_pilot_run:
|
||||
# 发送前3个陈述句作为示例
|
||||
for idx, stmt in enumerate(statements[:3]):
|
||||
stmt_result = {
|
||||
"extraction_type": "statement",
|
||||
"statement": stmt.statement,
|
||||
"statement_id": stmt.id,
|
||||
"chunk_progress": f"{completed_chunks}/{total_chunks}",
|
||||
"statement_index_in_chunk": idx + 1
|
||||
}
|
||||
await self.progress_callback(
|
||||
"knowledge_extraction_result",
|
||||
f"陈述句提取中 ({completed_chunks}/{total_chunks})",
|
||||
stmt_result
|
||||
)
|
||||
|
||||
return statements
|
||||
except Exception as e:
|
||||
logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}")
|
||||
completed_chunks += 1
|
||||
return []
|
||||
|
||||
tasks = [extract_for_chunk(chunk_data) for chunk_data in all_chunks]
|
||||
tasks = [extract_for_chunk(chunk_data, i) for i, chunk_data in enumerate(all_chunks)]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 将结果分配回对话
|
||||
@@ -394,7 +361,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取三元组(优化版:全局陈述句级并行)
|
||||
从对话中提取三元组(流式输出版本:边提取边发送进度)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -402,7 +369,7 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
三元组映射列表,每个对话对应一个字典
|
||||
"""
|
||||
logger.info("开始三元组提取(全局陈述句级并行)")
|
||||
logger.info("开始三元组提取(全局陈述句级并行 + 流式输出)")
|
||||
|
||||
# 收集所有陈述句及其元数据
|
||||
all_statements = []
|
||||
@@ -415,20 +382,32 @@ class ExtractionOrchestrator:
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
|
||||
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组")
|
||||
|
||||
# 用于跟踪已完成的陈述句数量
|
||||
completed_statements = 0
|
||||
total_statements = len(all_statements)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data):
|
||||
async def extract_for_statement(stmt_data, stmt_index):
|
||||
nonlocal completed_statements
|
||||
statement, chunk_content = stmt_data
|
||||
try:
|
||||
return await self.triplet_extractor._extract_triplets(statement, chunk_content)
|
||||
triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content)
|
||||
|
||||
# 注意:不再发送三元组提取的流式输出
|
||||
# 三元组提取在后台执行,但不向前端发送详细信息
|
||||
completed_statements += 1
|
||||
|
||||
return triplet_info
|
||||
except Exception as e:
|
||||
logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}")
|
||||
completed_statements += 1
|
||||
from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
|
||||
tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 将结果组织成对话级别的映射
|
||||
@@ -465,7 +444,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取时间信息(优化版:全局陈述句级并行)
|
||||
从对话中提取时间信息(流式输出版本:边提取边发送进度)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -473,7 +452,21 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
时间信息映射列表,每个对话对应一个字典
|
||||
"""
|
||||
logger.info("开始时间信息提取(全局陈述句级并行)")
|
||||
# 试运行模式:跳过时间提取以节省时间
|
||||
if self.is_pilot_run:
|
||||
logger.info("试运行模式:跳过时间信息提取(节省约 10-15 秒)")
|
||||
# 为所有陈述句返回空的时间范围
|
||||
from app.core.memory.models.message_models import TemporalValidityRange
|
||||
temporal_maps = []
|
||||
for dialog in dialog_data_list:
|
||||
temporal_map = {}
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None)
|
||||
temporal_maps.append(temporal_map)
|
||||
return temporal_maps
|
||||
|
||||
logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)")
|
||||
|
||||
# 收集所有需要提取时间的陈述句
|
||||
all_statements = []
|
||||
@@ -501,18 +494,30 @@ class ExtractionOrchestrator:
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
|
||||
logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取")
|
||||
|
||||
# 用于跟踪已完成的时间提取数量
|
||||
completed_temporal = 0
|
||||
total_temporal_statements = len(all_statements)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data):
|
||||
async def extract_for_statement(stmt_data, stmt_index):
|
||||
nonlocal completed_temporal
|
||||
statement, ref_dates = stmt_data
|
||||
try:
|
||||
return await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates)
|
||||
temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates)
|
||||
|
||||
# 注意:不再发送时间提取的流式输出
|
||||
# 时间提取在后台执行,但不向前端发送详细信息
|
||||
completed_temporal += 1
|
||||
|
||||
return temporal_range
|
||||
except Exception as e:
|
||||
logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}")
|
||||
completed_temporal += 1
|
||||
from app.core.memory.models.message_models import TemporalValidityRange
|
||||
return TemporalValidityRange(valid_at=None, invalid_at=None)
|
||||
|
||||
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
|
||||
tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 将结果组织成对话级别的映射
|
||||
@@ -542,9 +547,108 @@ class ExtractionOrchestrator:
|
||||
|
||||
return temporal_maps
|
||||
|
||||
async def _extract_emotions(
    self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
    """
    Extract emotion information for every statement, all statements in parallel.

    The config is looked up via the ``config_id`` of the first dialog only.
    Extraction is skipped (one empty dict per dialog is returned) when no
    config_id is available, the config cannot be loaded, or the config does
    not have emotion extraction enabled.

    Args:
        dialog_data_list: List of dialog data objects

    Returns:
        One dict per dialog mapping statement id -> extraction result
        (``None`` when extraction failed for that statement)
    """
    def _empty_maps() -> List[Dict[str, Any]]:
        # Fresh, independent dict per dialog.
        return [{} for _ in dialog_data_list]

    logger.info("开始情绪信息提取(全局陈述句级并行)")

    # The first dialog's config_id decides which config is loaded.
    config_id = (
        getattr(dialog_data_list[0], 'config_id', None)
        if dialog_data_list
        else None
    )
    if not config_id:
        logger.info("未找到config_id,跳过情绪提取")
        return _empty_maps()

    # Load the DataConfig; any failure here disables emotion extraction.
    data_config = None
    try:
        from app.db import SessionLocal
        from app.repositories.data_config_repository import DataConfigRepository

        session = SessionLocal()
        try:
            data_config = DataConfigRepository.get_by_id(session, config_id)
        finally:
            session.close()

        if data_config and not data_config.emotion_enabled:
            logger.info("情绪提取已在配置中禁用,跳过情绪提取")
            return _empty_maps()
    except Exception as exc:
        logger.warning(f"加载DataConfig失败: {exc},将跳过情绪提取")
        return _empty_maps()

    # Covers the "config not found" case as well as a disabled config.
    if not (data_config and data_config.emotion_enabled):
        logger.info("情绪提取未启用,跳过")
        return _empty_maps()

    # Flatten all statements, remembering which dialog each belongs to.
    pending_statements = []
    owners = []  # (dialog_idx, statement_id), parallel to pending_statements
    for dialog_idx, dialog in enumerate(dialog_data_list):
        for chunk in dialog.chunks:
            for stmt in chunk.statements:
                pending_statements.append(stmt)
                owners.append((dialog_idx, stmt.id))

    logger.info(f"收集到 {len(pending_statements)} 个陈述句,开始全局并行提取情绪")

    # Set up the emotion extraction service.
    from app.services.emotion_extraction_service import EmotionExtractionService
    emotion_service = EmotionExtractionService(
        llm_id=data_config.emotion_model_id or None
    )

    async def _extract_one(stmt):
        # A failure for one statement must not abort the others.
        try:
            return await emotion_service.extract_emotion(stmt.statement, data_config)
        except Exception as exc:
            logger.error(f"陈述句 {stmt.id} 情绪提取失败: {exc}")
            return None

    outcomes = await asyncio.gather(
        *(_extract_one(stmt) for stmt in pending_statements),
        return_exceptions=True,
    )

    # Regroup the flat results into one map per dialog.
    emotion_maps = _empty_maps()
    succeeded = 0
    for (dialog_idx, stmt_id), outcome in zip(owners, outcomes):
        if isinstance(outcome, Exception):
            logger.error(f"陈述句处理异常: {outcome}")
            emotion_maps[dialog_idx][stmt_id] = None
            continue
        emotion_maps[dialog_idx][stmt_id] = outcome
        if outcome is not None:
            succeeded += 1

    # Summary of how many statements yielded an emotion.
    logger.info(f"情绪信息提取完成,共成功提取 {succeeded}/{len(pending_statements)} 个情绪")

    return emotion_maps
|
||||
|
||||
async def _parallel_extract_and_embed(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> Tuple[
|
||||
List[Dict[str, Any]],
|
||||
List[Dict[str, Any]],
|
||||
List[Dict[str, Any]],
|
||||
List[Dict[str, List[float]]],
|
||||
@@ -552,35 +656,39 @@ class ExtractionOrchestrator:
|
||||
List[List[float]],
|
||||
]:
|
||||
"""
|
||||
并行执行三元组提取、时间信息提取和基础嵌入生成
|
||||
并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||
|
||||
这三个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行:
|
||||
这四个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行:
|
||||
- 三元组提取:从陈述句中提取实体和关系
|
||||
- 时间信息提取:从陈述句中提取时间范围
|
||||
- 情绪提取:从陈述句中提取情绪信息
|
||||
- 嵌入生成:为陈述句、分块和对话生成向量(不依赖三元组)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
|
||||
Returns:
|
||||
五个列表的元组:
|
||||
六个列表的元组:
|
||||
- 三元组映射列表
|
||||
- 时间信息映射列表
|
||||
- 情绪映射列表
|
||||
- 陈述句嵌入映射列表
|
||||
- 分块嵌入映射列表
|
||||
- 对话嵌入列表
|
||||
"""
|
||||
logger.info("并行执行:三元组提取 + 时间信息提取 + 基础嵌入生成")
|
||||
logger.info("并行执行:三元组提取 + 时间信息提取 + 情绪提取 + 基础嵌入生成")
|
||||
|
||||
# 创建三个并行任务
|
||||
# 创建四个并行任务
|
||||
triplet_task = self._extract_triplets(dialog_data_list)
|
||||
temporal_task = self._extract_temporal(dialog_data_list)
|
||||
emotion_task = self._extract_emotions(dialog_data_list)
|
||||
embedding_task = self._generate_basic_embeddings(dialog_data_list)
|
||||
|
||||
# 并行执行
|
||||
results = await asyncio.gather(
|
||||
triplet_task,
|
||||
temporal_task,
|
||||
emotion_task,
|
||||
embedding_task,
|
||||
return_exceptions=True
|
||||
)
|
||||
@@ -588,19 +696,21 @@ class ExtractionOrchestrator:
|
||||
# 解包结果
|
||||
triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list]
|
||||
temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list]
|
||||
emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list]
|
||||
|
||||
if isinstance(results[2], Exception):
|
||||
logger.error(f"基础嵌入生成失败: {results[2]}")
|
||||
if isinstance(results[3], Exception):
|
||||
logger.error(f"基础嵌入生成失败: {results[3]}")
|
||||
statement_embedding_maps = [{} for _ in dialog_data_list]
|
||||
chunk_embedding_maps = [{} for _ in dialog_data_list]
|
||||
dialog_embeddings = [[] for _ in dialog_data_list]
|
||||
else:
|
||||
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[2]
|
||||
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[3]
|
||||
|
||||
logger.info("并行任务执行完成")
|
||||
return (
|
||||
triplet_maps,
|
||||
temporal_maps,
|
||||
emotion_maps,
|
||||
statement_embedding_maps,
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
@@ -711,6 +821,7 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list: List[DialogData],
|
||||
temporal_maps: List[Dict[str, Any]],
|
||||
triplet_maps: List[Dict[str, Any]],
|
||||
emotion_maps: List[Dict[str, Any]],
|
||||
statement_embedding_maps: List[Dict[str, List[float]]],
|
||||
chunk_embedding_maps: List[Dict[str, List[float]]],
|
||||
dialog_embeddings: List[List[float]],
|
||||
@@ -722,6 +833,7 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list: 对话数据列表
|
||||
temporal_maps: 时间信息映射列表
|
||||
triplet_maps: 三元组映射列表
|
||||
emotion_maps: 情绪信息映射列表
|
||||
statement_embedding_maps: 陈述句嵌入映射列表
|
||||
chunk_embedding_maps: 分块嵌入映射列表
|
||||
dialog_embeddings: 对话嵌入列表
|
||||
@@ -736,6 +848,7 @@ class ExtractionOrchestrator:
|
||||
if (
|
||||
len(temporal_maps) != expected_length
|
||||
or len(triplet_maps) != expected_length
|
||||
or len(emotion_maps) != expected_length
|
||||
or len(statement_embedding_maps) != expected_length
|
||||
or len(chunk_embedding_maps) != expected_length
|
||||
or len(dialog_embeddings) != expected_length
|
||||
@@ -743,6 +856,7 @@ class ExtractionOrchestrator:
|
||||
logger.warning(
|
||||
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
|
||||
f"时间映射: {len(temporal_maps)}, 三元组映射: {len(triplet_maps)}, "
|
||||
f"情绪映射: {len(emotion_maps)}, "
|
||||
f"陈述句嵌入: {len(statement_embedding_maps)}, "
|
||||
f"分块嵌入: {len(chunk_embedding_maps)}, "
|
||||
f"对话嵌入: {len(dialog_embeddings)}"
|
||||
@@ -751,6 +865,7 @@ class ExtractionOrchestrator:
|
||||
total_statements = 0
|
||||
assigned_temporal = 0
|
||||
assigned_triplets = 0
|
||||
assigned_emotions = 0
|
||||
assigned_statement_embeddings = 0
|
||||
assigned_chunk_embeddings = 0
|
||||
assigned_dialog_embeddings = 0
|
||||
@@ -758,12 +873,13 @@ class ExtractionOrchestrator:
|
||||
# 处理每个对话
|
||||
for i, dialog_data in enumerate(dialog_data_list):
|
||||
# 检查是否有缺失的数据
|
||||
if i >= len(temporal_maps) or i >= len(triplet_maps):
|
||||
if i >= len(temporal_maps) or i >= len(triplet_maps) or i >= len(emotion_maps):
|
||||
logger.warning(f"对话 {dialog_data.id} 缺少提取数据,跳过赋值")
|
||||
continue
|
||||
|
||||
temporal_map = temporal_maps[i]
|
||||
triplet_map = triplet_maps[i]
|
||||
emotion_map = emotion_maps[i]
|
||||
statement_embedding_map = statement_embedding_maps[i] if i < len(statement_embedding_maps) else {}
|
||||
chunk_embedding_map = chunk_embedding_maps[i] if i < len(chunk_embedding_maps) else {}
|
||||
dialog_embedding = dialog_embeddings[i] if i < len(dialog_embeddings) else []
|
||||
@@ -794,6 +910,18 @@ class ExtractionOrchestrator:
|
||||
statement.triplet_extraction_info = triplet_map[statement.id]
|
||||
assigned_triplets += 1
|
||||
|
||||
# 赋值情绪信息
|
||||
if statement.id in emotion_map:
|
||||
emotion_data = emotion_map[statement.id]
|
||||
if emotion_data is not None:
|
||||
# 将EmotionExtraction对象的字段赋值到Statement
|
||||
statement.emotion_type = emotion_data.emotion_type
|
||||
statement.emotion_intensity = emotion_data.emotion_intensity
|
||||
statement.emotion_keywords = emotion_data.emotion_keywords
|
||||
statement.emotion_subject = emotion_data.emotion_subject
|
||||
statement.emotion_target = emotion_data.emotion_target
|
||||
assigned_emotions += 1
|
||||
|
||||
# 赋值陈述句嵌入
|
||||
if statement.id in statement_embedding_map:
|
||||
statement.statement_embedding = statement_embedding_map[statement.id]
|
||||
@@ -802,6 +930,7 @@ class ExtractionOrchestrator:
|
||||
logger.info(
|
||||
f"数据赋值完成 - 总陈述句: {total_statements}, "
|
||||
f"时间信息: {assigned_temporal}, 三元组: {assigned_triplets}, "
|
||||
f"情绪信息: {assigned_emotions}, "
|
||||
f"陈述句嵌入: {assigned_statement_embeddings}, "
|
||||
f"分块嵌入: {assigned_chunk_embeddings}, "
|
||||
f"对话嵌入: {assigned_dialog_embeddings}"
|
||||
@@ -833,9 +962,7 @@ class ExtractionOrchestrator:
|
||||
"""
|
||||
logger.info("开始创建节点和边")
|
||||
|
||||
# 进度回调:正在创建节点和边
|
||||
if self.progress_callback:
|
||||
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
|
||||
# 注意:开始消息已在 run 方法中发送,这里不再重复发送
|
||||
|
||||
dialogue_nodes = []
|
||||
chunk_nodes = []
|
||||
@@ -847,8 +974,13 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 用于去重的集合
|
||||
entity_id_set = set()
|
||||
|
||||
# 用于跟踪进度
|
||||
total_dialogs = len(dialog_data_list)
|
||||
processed_dialogs = 0
|
||||
|
||||
for dialog_data in dialog_data_list:
|
||||
processed_dialogs += 1
|
||||
# 创建对话节点
|
||||
dialogue_node = DialogueNode(
|
||||
id=dialog_data.id,
|
||||
@@ -908,6 +1040,12 @@ class ExtractionOrchestrator:
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||
# Emotion fields
|
||||
emotion_type=getattr(statement, 'emotion_type', None),
|
||||
emotion_intensity=getattr(statement, 'emotion_intensity', None),
|
||||
emotion_keywords=getattr(statement, 'emotion_keywords', None),
|
||||
emotion_subject=getattr(statement, 'emotion_subject', None),
|
||||
emotion_target=getattr(statement, 'emotion_target', None),
|
||||
)
|
||||
statement_nodes.append(statement_node)
|
||||
|
||||
@@ -995,6 +1133,26 @@ class ExtractionOrchestrator:
|
||||
expired_at=dialog_data.expired_at,
|
||||
)
|
||||
entity_entity_edges.append(entity_entity_edge)
|
||||
|
||||
# 流式输出:每创建一个关系边,立即发送进度(限制发送数量)
|
||||
if self.progress_callback and len(entity_entity_edges) <= 10:
|
||||
# 获取实体名称
|
||||
source_name = triplet.subject_name
|
||||
target_name = triplet.object_name
|
||||
relationship_result = {
|
||||
"result_type": "relationship_creation",
|
||||
"relationship_index": len(entity_entity_edges),
|
||||
"source_entity": source_name,
|
||||
"relation_type": triplet.predicate,
|
||||
"target_entity": target_name,
|
||||
"relationship_text": f"{source_name} -[{triplet.predicate}]-> {target_name}",
|
||||
"dialog_progress": f"{processed_dialogs}/{total_dialogs}"
|
||||
}
|
||||
await self.progress_callback(
|
||||
"creating_nodes_edges_result",
|
||||
f"关系创建中 ({processed_dialogs}/{total_dialogs})",
|
||||
relationship_result
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, "
|
||||
@@ -1009,12 +1167,9 @@ class ExtractionOrchestrator:
|
||||
f"实体-实体边: {len(entity_entity_edges)}"
|
||||
)
|
||||
|
||||
# 进度回调:只输出关系创建结果
|
||||
# 进度回调:创建节点和边完成,传递结果统计
|
||||
# 注意:具体的关系创建结果已经在创建过程中实时发送了
|
||||
if self.progress_callback:
|
||||
# 输出关系创建结果
|
||||
await self._output_relationship_creation_results(entity_entity_edges, entity_nodes)
|
||||
|
||||
# 进度回调:创建节点和边完成,传递结果统计
|
||||
nodes_edges_stats = {
|
||||
"dialogue_nodes_count": len(dialogue_nodes),
|
||||
"chunk_nodes_count": len(chunk_nodes),
|
||||
@@ -1072,7 +1227,7 @@ class ExtractionOrchestrator:
|
||||
"""
|
||||
logger.info("开始两阶段实体去重和消歧")
|
||||
|
||||
# 进度回调:正在去重消歧
|
||||
# 进度回调:发送去重消歧开始消息
|
||||
if self.progress_callback:
|
||||
await self.progress_callback("deduplication", "正在去重消歧...")
|
||||
|
||||
@@ -1157,25 +1312,26 @@ class ExtractionOrchestrator:
|
||||
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
|
||||
)
|
||||
|
||||
# 进度回调:输出去重消歧的具体结果
|
||||
# 流式输出:实时输出去重消歧的具体结果
|
||||
if self.progress_callback:
|
||||
# 分析实体合并情况
|
||||
# 分析实体合并情况(使用内存中的记录)
|
||||
merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes)
|
||||
|
||||
# 输出去重合并的实体示例
|
||||
# 逐个输出去重合并的实体示例
|
||||
for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果
|
||||
dedup_result = {
|
||||
"result_type": "entity_merge",
|
||||
"merged_entity_name": merge_detail["main_entity_name"],
|
||||
"merged_count": merge_detail["merged_count"],
|
||||
"merge_progress": f"{i + 1}/{min(len(merge_info), 5)}",
|
||||
"message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result)
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result)
|
||||
|
||||
# 分析实体消歧情况
|
||||
# 分析实体消歧情况(使用内存中的记录)
|
||||
disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes)
|
||||
|
||||
# 输出实体消歧的结果
|
||||
# 逐个输出实体消歧的结果
|
||||
for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果
|
||||
disamb_result = {
|
||||
"result_type": "entity_disambiguation",
|
||||
@@ -1183,11 +1339,10 @@ class ExtractionOrchestrator:
|
||||
"disambiguation_type": disamb_detail["disamb_type"],
|
||||
"confidence": disamb_detail.get("confidence", "unknown"),
|
||||
"reason": disamb_detail.get("reason", ""),
|
||||
"disamb_progress": f"{i + 1}/{min(len(disamb_info), 5)}",
|
||||
"message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result)
|
||||
|
||||
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result)
|
||||
|
||||
# 进度回调:去重消歧完成,传递去重和消歧的具体效果
|
||||
await self._send_dedup_progress_callback(
|
||||
@@ -1299,7 +1454,7 @@ class ExtractionOrchestrator:
|
||||
if match:
|
||||
entity1_name = match.group(1).strip()
|
||||
entity1_type = match.group(2)
|
||||
entity2_name = match.group(3).strip()
|
||||
match.group(3).strip()
|
||||
entity2_type = match.group(4)
|
||||
|
||||
# 提取置信度和原因
|
||||
@@ -1611,7 +1766,6 @@ async def get_chunked_dialogs(
|
||||
包含分块的 DialogData 对象列表
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
# 加载测试数据
|
||||
@@ -1794,7 +1948,6 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
Returns:
|
||||
带 chunks 的 DialogData 列表
|
||||
"""
|
||||
import os
|
||||
print("\n=== 完整数据处理流程(包含预处理)===")
|
||||
|
||||
if input_data_path is None:
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
{
|
||||
"memory_verify": {
|
||||
"source_data": [
|
||||
{
|
||||
"statement_name": "用户是2023年春天去北京工作的。",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户后来基本一直都在北京上班。",
|
||||
"statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从2023年开始就一直在北京生活。",
|
||||
"statement_id": "e612a44da4db483993c350df7c97a1a1",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从来没有长期离开过北京。",
|
||||
"statement_id": "b3c787a2e33c49f7981accabbbb4538a",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "由于公司调整,用户在2024年上半年被调到上海待了差不多半年。",
|
||||
"statement_id": "64cde4230cb24a4da726e7db9e7aa616",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在入职时使用的身份信息是之前的,身份证号为11010119950308123X。",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的银行卡号是6222023847595898。",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的身份信息和银行卡信息一直没变。",
|
||||
"statement_id": "b3ca618e1e204b83bebd70e75cf2073f",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户认为在上海的那段时间更多算是远程配合。",
|
||||
"statement_id": "150af89d2c154e6eb41ff1a91e37f962",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
}
|
||||
],
|
||||
"databasets": [
|
||||
{
|
||||
"entity1_name": "Person",
|
||||
"description": "表示人类个体的通用类型",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "用户",
|
||||
"entity2": {
|
||||
"entity_idx": 0,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Person",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "用户",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "3d3896797b334572a80d57590026063d"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "身份信息",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "Strong",
|
||||
"description": "用于个人身份识别的数据",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Information",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "身份信息",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "aa766a517e82490599a9b3af54cfd933"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "6222023847595898",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "Strong",
|
||||
"description": "用户的银行卡号码",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Numeric",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "6222023847595898",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "610ba361918f4e68a65ce6ad06e5c7a0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "上海办公室",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"aliases": ["上海办"],
|
||||
"connect_strength": "Strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "位于上海的工作办公场所",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Location",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "上海办公室",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "fb702ef695c14e14af3e56786bc8815b"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "北京",
|
||||
"entity2": {
|
||||
"entity_idx": 2,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"aliases": ["京", "京城", "北平"],
|
||||
"connect_strength": "strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "中国的首都城市,用户主要工作和生活所在地",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Location",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "北京",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "81b2d1a571bb46a08a2d7a1e87efb945"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "11010119950308123X",
|
||||
"description": "具体的身份证号码值",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "身份证号",
|
||||
"entity2": {
|
||||
"entity_idx": 2,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "strong",
|
||||
"description": "中华人民共和国公民的身份号码",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Identifier",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "身份证号",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "3e5f920645b2404fadb0e9ff60d1306e"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -8,17 +8,21 @@
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# 配置日志
|
||||
_root_logger = logging.getLogger()
|
||||
@@ -33,14 +37,14 @@ else:
|
||||
|
||||
class ReflectionRange(str, Enum):
|
||||
"""反思范围枚举"""
|
||||
RETRIEVAL = "retrieval" # 从检索结果中反思
|
||||
DATABASE = "database" # 从整个数据库中反思
|
||||
PARTIAL = "partial" # 从检索结果中反思
|
||||
ALL = "all" # 从整个数据库中反思
|
||||
|
||||
|
||||
class ReflectionBaseline(str, Enum):
|
||||
"""反思基线枚举"""
|
||||
TIME = "TIME" # 基于时间的反思
|
||||
FACT = "FACT" # 基于事实的反思
|
||||
TIME = "TIME" # 基于时间的反思
|
||||
FACT = "FACT" # 基于事实的反思
|
||||
HYBRID = "HYBRID" # 混合反思
|
||||
|
||||
|
||||
@@ -48,9 +52,16 @@ class ReflectionConfig(BaseModel):
|
||||
"""反思引擎配置"""
|
||||
enabled: bool = False
|
||||
iteration_period: str = "3" # 反思周期
|
||||
reflexion_range: ReflectionRange = ReflectionRange.RETRIEVAL
|
||||
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
|
||||
baseline: ReflectionBaseline = ReflectionBaseline.TIME
|
||||
concurrency: int = Field(default=5, description="并发数量")
|
||||
model_id: Optional[str] = None # 模型ID
|
||||
end_user_id: Optional[str] = None
|
||||
output_example: Optional[str] = None # 输出示例
|
||||
|
||||
# 评估相关字段
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
@@ -75,16 +86,16 @@ class ReflectionEngine:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ReflectionConfig,
|
||||
neo4j_connector: Optional[Any] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
get_data_func: Optional[Any] = None,
|
||||
render_evaluate_prompt_func: Optional[Any] = None,
|
||||
render_reflexion_prompt_func: Optional[Any] = None,
|
||||
conflict_schema: Optional[Any] = None,
|
||||
reflexion_schema: Optional[Any] = None,
|
||||
update_query: Optional[str] = None
|
||||
self,
|
||||
config: ReflectionConfig,
|
||||
neo4j_connector: Optional[Any] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
get_data_func: Optional[Any] = None,
|
||||
render_evaluate_prompt_func: Optional[Any] = None,
|
||||
render_reflexion_prompt_func: Optional[Any] = None,
|
||||
conflict_schema: Optional[Any] = None,
|
||||
reflexion_schema: Optional[Any] = None,
|
||||
update_query: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化反思引擎
|
||||
@@ -109,7 +120,7 @@ class ReflectionEngine:
|
||||
self.conflict_schema = conflict_schema
|
||||
self.reflexion_schema = reflexion_schema
|
||||
self.update_query = update_query
|
||||
self._semaphore = asyncio.Semaphore(config.concurrency)
|
||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
self._lazy_init_done = False
|
||||
@@ -127,11 +138,21 @@ class ReflectionEngine:
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
model_id = self.llm_client
|
||||
self.llm_client = get_llm_client(model_id)
|
||||
|
||||
if self.get_data_func is None:
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
self.get_data_func = get_data
|
||||
|
||||
# 导入get_data_statement函数
|
||||
if not hasattr(self, 'get_data_statement'):
|
||||
from app.core.memory.utils.config.get_data import get_data_statement
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
if self.render_evaluate_prompt_func is None:
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
self.render_evaluate_prompt_func = render_evaluate_prompt
|
||||
@@ -154,13 +175,11 @@ class ReflectionEngine:
|
||||
|
||||
self._lazy_init_done = True
|
||||
|
||||
async def execute_reflection(self, host_id: uuid.UUID) -> ReflectionResult:
|
||||
async def execute_reflection(self, host_id) -> ReflectionResult:
|
||||
"""
|
||||
执行完整的反思流程
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
"""
|
||||
@@ -176,9 +195,10 @@ class ReflectionEngine:
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
logging.info("====== 自我反思流程开始 ======")
|
||||
|
||||
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
|
||||
try:
|
||||
# 1. 获取反思数据
|
||||
reflexion_data = await self._get_reflexion_data(host_id)
|
||||
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
|
||||
if not reflexion_data:
|
||||
return ReflectionResult(
|
||||
success=True,
|
||||
@@ -187,22 +207,21 @@ class ReflectionEngine:
|
||||
)
|
||||
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data)
|
||||
if not conflict_data:
|
||||
return ReflectionResult(
|
||||
success=True,
|
||||
message="无冲突,无需反思",
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||
print(100 * '-')
|
||||
print(conflict_data)
|
||||
print(100 * '-')
|
||||
|
||||
conflicts_found = len(conflict_data)
|
||||
logging.info(f"发现 {conflicts_found} 个冲突")
|
||||
# 检查是否真的有冲突
|
||||
has_conflict = conflict_data[0].get('conflict', False)
|
||||
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
|
||||
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
|
||||
|
||||
# 记录冲突数据
|
||||
await self._log_data("conflict", conflict_data)
|
||||
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(conflict_data)
|
||||
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
success=False,
|
||||
@@ -210,6 +229,9 @@ class ReflectionEngine:
|
||||
conflicts_found=conflicts_found,
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
print(100 * '*')
|
||||
print(solved_data)
|
||||
print(100 * '*')
|
||||
|
||||
conflicts_resolved = len(solved_data)
|
||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||
@@ -230,7 +252,8 @@ class ReflectionEngine:
|
||||
conflicts_found=conflicts_found,
|
||||
conflicts_resolved=conflicts_resolved,
|
||||
memories_updated=memories_updated,
|
||||
execution_time=execution_time
|
||||
execution_time=execution_time,
|
||||
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -241,6 +264,79 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
async def reflection_run(self):
    """Run the reflection pipeline against the bundled example data.

    Loads ``source_data`` and ``databasets`` via
    ``extract_fields_from_json``, runs conflict detection and conflict
    resolution on them, and collects the per-item assessment fields.

    Returns:
        On success, a dict with the keys ``baseline``, ``source_data``,
        ``quality_assessments``, ``memory_verifies`` and
        ``reflexion_data``. When conflict resolution yields nothing, a
        failed ``ReflectionResult`` is returned instead.
    """
    self._lazy_init()
    # Take the start time from the event-loop clock: the failure path below
    # subtracts it from another event-loop reading, so both ends of the
    # measurement must come from the same clock.
    start_time = asyncio.get_event_loop().time()
    logging.info("====== 自我反思流程开始 ======")

    result_data = {}

    source_data, databasets = await self.extract_fields_from_json()
    result_data['baseline'] = self.config.baseline
    result_data[
        'source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"

    # 2. Detect conflicts (fact-based reflection)
    conflict_data = await self._detect_conflicts(databasets, source_data)

    # Collect the assessment fields of every detection item
    quality_assessments = []
    memory_verifies = []
    for item in conflict_data:
        logging.debug("conflict item: %s", item)
        quality_assessments.append(item['quality_assessment'])
        memory_verifies.append(item['memory_verify'])
    result_data['quality_assessments'] = quality_assessments
    result_data['memory_verifies'] = memory_verifies

    # Check whether a conflict was actually reported. Detection may return
    # an empty list, so guard the first-item access instead of raising
    # IndexError.
    first_item = conflict_data[0] if conflict_data else {}
    has_conflict = first_item.get('conflict', False)
    conflicts_found = len(first_item['data']) if has_conflict else 0
    logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")

    # Record the conflict data
    await self._log_data("conflict", conflict_data)

    # 3. Resolve conflicts
    solved_data = await self._resolve_conflicts(conflict_data, source_data)
    if not solved_data:
        return ReflectionResult(
            success=False,
            message="反思失败,未解决冲突",
            conflicts_found=conflicts_found,
            execution_time=asyncio.get_event_loop().time() - start_time
        )

    # Gather the 'reflexion' field of every result in the solved items;
    # items without a 'results' key are skipped.
    result_data['reflexion_data'] = [
        result['reflexion']
        for item in solved_data
        if 'results' in item
        for result in item['results']
    ]
    return result_data
|
||||
|
||||
|
||||
async def extract_fields_from_json(self):
    """Load ``source_data`` and ``databasets`` from ``example/example.json``.

    The file is looked up in the ``example`` directory next to this module,
    and both lists are read from its top-level ``memory_verify`` object.

    Returns:
        A ``(source_data, databasets)`` tuple. Keys that are missing default
        to empty lists. If the file cannot be read or parsed, the failure is
        logged and ``([], [])`` is returned.
    """
    example_path = os.path.join(os.path.dirname(__file__), "example", "example.json")
    try:
        # Read the JSON file
        with open(example_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Pull the fields that live under memory_verify
        memory_verify = data.get("memory_verify", {})
        source_data = memory_verify.get("source_data", [])
        databasets = memory_verify.get("databasets", [])

        return source_data, databasets

    except Exception as e:
        # Best-effort loader: callers still get empty data, but the cause is
        # logged instead of being silently discarded.
        logging.error(f"加载示例数据失败 ({example_path}): {e}")
        return [], []
|
||||
|
||||
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
|
||||
"""
|
||||
获取反思数据
|
||||
@@ -253,17 +349,28 @@ class ReflectionEngine:
|
||||
Returns:
|
||||
List[Any]: 反思数据列表
|
||||
"""
|
||||
if self.config.reflexion_range == ReflectionRange.RETRIEVAL:
|
||||
# 从检索结果中获取数据
|
||||
return await self.get_data_func(host_id)
|
||||
elif self.config.reflexion_range == ReflectionRange.DATABASE:
|
||||
# 从整个数据库中获取数据(待实现)
|
||||
logging.warning("从数据库获取反思数据功能尚未实现")
|
||||
return []
|
||||
else:
|
||||
raise ValueError(f"未知的反思范围: {self.config.reflexion_range}")
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any]) -> List[Any]:
|
||||
|
||||
|
||||
if self.config.reflexion_range == ReflectionRange.PARTIAL:
|
||||
neo4j_query = neo4j_query_part.format(host_id)
|
||||
neo4j_statement = neo4j_statement_part.format(host_id)
|
||||
elif self.config.reflexion_range == ReflectionRange.ALL:
|
||||
neo4j_query = neo4j_query_all.format(host_id)
|
||||
neo4j_statement = neo4j_statement_all.format(host_id)
|
||||
try:
|
||||
result = await self.neo4j_connector.execute_query(neo4j_query)
|
||||
result_statement = await self.neo4j_connector.execute_query(neo4j_statement)
|
||||
neo4j_databasets = await self.get_data_func(result)
|
||||
neo4j_state = await self.get_data_statement(result_statement)
|
||||
return neo4j_databasets, neo4j_state
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Neo4j查询失败: {e}")
|
||||
return [], []
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
检测冲突(基于事实的反思)
|
||||
|
||||
@@ -278,14 +385,28 @@ class ReflectionEngine:
|
||||
if not data:
|
||||
return []
|
||||
|
||||
# 数据预处理:如果数据量太少,直接返回无冲突
|
||||
if len(data) < 2:
|
||||
logging.info("数据量不足,无需检测冲突")
|
||||
return []
|
||||
|
||||
# 使用转换后的数据
|
||||
print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
quality_assessment = self.config.quality_assessment
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
rendered_prompt = await self.render_evaluate_prompt_func(
|
||||
data,
|
||||
self.conflict_schema
|
||||
self.conflict_schema,
|
||||
self.config.baseline,
|
||||
memory_verify,
|
||||
quality_assessment,
|
||||
statement_databasets
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
@@ -316,7 +437,7 @@ class ReflectionEngine:
|
||||
logging.error(f"冲突检测失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def _resolve_conflicts(self, conflicts: List[Any]) -> List[Any]:
|
||||
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
解决冲突
|
||||
|
||||
@@ -332,6 +453,8 @@ class ReflectionEngine:
|
||||
return []
|
||||
|
||||
logging.info("====== 冲突解决开始 ======")
|
||||
baseline = self.config.baseline
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
# 并行处理每个冲突
|
||||
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
|
||||
@@ -341,7 +464,10 @@ class ReflectionEngine:
|
||||
# 渲染反思提示词
|
||||
rendered_prompt = await self.render_reflexion_prompt_func(
|
||||
[conflict],
|
||||
self.reflexion_schema
|
||||
self.reflexion_schema,
|
||||
baseline,
|
||||
memory_verify,
|
||||
statement_databasets
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
@@ -381,8 +507,8 @@ class ReflectionEngine:
|
||||
return solved
|
||||
|
||||
async def _apply_reflection_results(
|
||||
self,
|
||||
solved_data: List[Dict[str, Any]]
|
||||
self,
|
||||
solved_data: List[Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
应用反思结果(更新记忆库)
|
||||
@@ -395,57 +521,7 @@ class ReflectionEngine:
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
"""
|
||||
if not solved_data:
|
||||
logging.warning("无解决方案数据,跳过更新")
|
||||
return 0
|
||||
|
||||
logging.info("====== 记忆更新开始 ======")
|
||||
|
||||
success_count = 0
|
||||
|
||||
async def _update_one(item: Dict[str, Any]) -> bool:
|
||||
"""更新单条记忆"""
|
||||
async with self._semaphore:
|
||||
try:
|
||||
if not isinstance(item, dict):
|
||||
return False
|
||||
|
||||
# 提取更新参数
|
||||
resolved = item.get("resolved", {})
|
||||
resolved_mem = resolved.get("resolved_memory", {})
|
||||
group_id = resolved_mem.get("group_id")
|
||||
memory_id = resolved_mem.get("id")
|
||||
new_invalid_at = resolved_mem.get("invalid_at")
|
||||
|
||||
if not all([group_id, memory_id, new_invalid_at]):
|
||||
logging.warning(f"记忆更新参数缺失,跳过此项: {item}")
|
||||
return False
|
||||
|
||||
# 执行更新
|
||||
await self.neo4j_connector.execute_query(
|
||||
self.update_query,
|
||||
group_id=group_id,
|
||||
id=memory_id,
|
||||
new_invalid_at=new_invalid_at,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"更新单条记忆失败: {e}")
|
||||
return False
|
||||
|
||||
# 并发执行所有更新任务
|
||||
tasks = [
|
||||
_update_one(item)
|
||||
for item in solved_data
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
success_count = sum(1 for r in results if r)
|
||||
|
||||
logging.info(f"成功更新 {success_count}/{len(solved_data)} 条记忆")
|
||||
|
||||
success_count = await neo4j_data(solved_data)
|
||||
return success_count
|
||||
|
||||
async def _log_data(self, label: str, data: Any) -> None:
|
||||
@@ -456,6 +532,7 @@ class ReflectionEngine:
|
||||
label: 数据标签
|
||||
data: 要记录的数据
|
||||
"""
|
||||
|
||||
def _write():
|
||||
try:
|
||||
with open("reflexion_data.json", "a", encoding="utf-8") as f:
|
||||
@@ -470,9 +547,9 @@ class ReflectionEngine:
|
||||
|
||||
# 基于时间的反思方法
|
||||
async def time_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID,
|
||||
time_period: Optional[str] = None
|
||||
self,
|
||||
host_id: uuid.UUID,
|
||||
time_period: Optional[str] = None
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于时间的反思
|
||||
@@ -494,8 +571,8 @@ class ReflectionEngine:
|
||||
|
||||
# 基于事实的反思方法
|
||||
async def fact_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于事实的反思
|
||||
@@ -515,8 +592,8 @@ class ReflectionEngine:
|
||||
|
||||
# 综合反思方法
|
||||
async def comprehensive_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
综合反思
|
||||
@@ -553,33 +630,3 @@ class ReflectionEngine:
|
||||
else:
|
||||
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
||||
|
||||
|
||||
# 便捷函数:创建默认配置的反思引擎
|
||||
def create_reflection_engine(
    enabled: bool = False,
    iteration_period: str = "3",
    reflexion_range: str = "retrieval",
    baseline: str = "TIME",
    concurrency: int = 5
) -> ReflectionEngine:
    """
    Build a ReflectionEngine from individual configuration values.

    Args:
        enabled: Whether reflection is enabled
        iteration_period: Reflection period
        reflexion_range: Reflection range
        baseline: Reflection baseline
        concurrency: Concurrency level

    Returns:
        ReflectionEngine: An engine constructed with only the resulting config
    """
    # Gather the settings first, then hand them to the config model in one go.
    settings = {
        "enabled": enabled,
        "iteration_period": iteration_period,
        "reflexion_range": reflexion_range,
        "baseline": baseline,
        "concurrency": concurrency,
    }
    engine_config = ReflectionConfig(**settings)
    return ReflectionEngine(engine_config)
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import get_db
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
from app.schemas.memory_storage_schema import BaseDataSchema
|
||||
|
||||
import logging
|
||||
|
||||
from typing import List, Dict, Any
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def _load_(data: List[Any]) -> List[Dict]:
|
||||
@@ -60,27 +55,46 @@ async def _load_(data: List[Any]) -> List[Dict]:
|
||||
return results
|
||||
|
||||
|
||||
async def get_data(host_id: uuid.UUID) -> List[Dict]:
|
||||
async def get_data(result):
|
||||
"""
|
||||
从数据库中获取数据
|
||||
"""
|
||||
# 从数据库会话中获取会话
|
||||
db: Session = next(get_db())
|
||||
try:
|
||||
data = db.query(RetrievalInfo.retrieve_info).filter(RetrievalInfo.host_id == host_id).all()
|
||||
neo4j_databasets=[]
|
||||
for item in result:
|
||||
filtered_item = {}
|
||||
for key, value in item.items():
|
||||
if 'name_embedding' not in key.lower():
|
||||
if key == 'relationship' and value is not None:
|
||||
# 只保留relationship的指定字段
|
||||
rel_filtered = {}
|
||||
if hasattr(value, 'get'):
|
||||
rel_filtered['run_id'] = value.get('run_id')
|
||||
rel_filtered['statement'] = value.get('statement')
|
||||
rel_filtered['statement_id'] = value.get('statement_id')
|
||||
rel_filtered['expired_at'] = value.get('expired_at')
|
||||
rel_filtered['created_at'] = value.get('created_at')
|
||||
filtered_item[key] = rel_filtered
|
||||
elif key == 'entity2' and value is not None:
|
||||
# 过滤entity2的name_embedding字段
|
||||
entity2_filtered = {}
|
||||
if hasattr(value, 'items'):
|
||||
for e_key, e_value in value.items():
|
||||
if 'name_embedding' not in e_key.lower():
|
||||
entity2_filtered[e_key] = e_value
|
||||
filtered_item[key] = entity2_filtered
|
||||
else:
|
||||
filtered_item[key] = value
|
||||
|
||||
# 直接将字典添加到列表中
|
||||
neo4j_databasets.append(filtered_item)
|
||||
return neo4j_databasets
|
||||
async def get_data_statement(result):
    """Collect the items of ``result`` into a new list.

    Args:
        result: Any iterable of records; every item is kept exactly as it
            is, without filtering or conversion.

    Returns:
        A new list holding the items of ``result`` in iteration order.
    """
    # list() performs the same shallow copy as an explicit append loop.
    return list(result)
|
||||
|
||||
|
||||
# print(f"data:\n{data}")
|
||||
# 解析,提取为字典的列表
|
||||
results = await _load_(data)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"failed to get data from database, host_id: {host_id}, error: {e}")
|
||||
raise e
|
||||
finally:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -238,3 +238,81 @@ async def render_memory_summary_prompt(
|
||||
'json_schema': 'MemorySummaryResponse.schema'
|
||||
})
|
||||
return rendered_prompt
|
||||
|
||||
async def render_emotion_extraction_prompt(
    statement: str,
    extract_keywords: bool,
    enable_subject: bool
) -> str:
    """
    Build the emotion extraction prompt from the extract_emotion.jinja2 template.

    Args:
        statement: The statement to analyze
        extract_keywords: Passed to the template to toggle the emotion keyword section
        enable_subject: Passed to the template to toggle the subject classification section

    Returns:
        The rendered prompt text
    """
    template_name = "extract_emotion.jinja2"
    render_context = {
        "statement": statement,
        "extract_keywords": extract_keywords,
        "enable_subject": enable_subject,
    }
    prompt_text = prompt_env.get_template(template_name).render(**render_context)

    # Record the rendered result in the prompt log.
    log_prompt_rendering('emotion extraction', prompt_text)

    # Optional: record what the template was rendered with. The statement
    # itself is not logged here, only the placeholder 'str'.
    logged_fields = {
        'statement': 'str',
        'extract_keywords': extract_keywords,
        'enable_subject': enable_subject
    }
    log_template_rendering(template_name, logged_fields)

    return prompt_text
|
||||
|
||||
async def render_emotion_suggestions_prompt(
    health_data: dict,
    patterns: dict,
    user_profile: dict
) -> str:
    """
    Build the emotion suggestions prompt from the generate_emotion_suggestions.jinja2 template.

    Args:
        health_data: Emotional health data; its 'emotion_distribution' entry
            is additionally serialized to JSON for the template
        patterns: Emotion pattern analysis result
        user_profile: User profile data

    Returns:
        The rendered prompt text
    """
    import json

    # Pre-serialize emotion_distribution so the template receives a JSON string.
    distribution = health_data.get('emotion_distribution', {})
    emotion_distribution_json = json.dumps(distribution, ensure_ascii=False, indent=2)

    template_name = "generate_emotion_suggestions.jinja2"
    prompt_text = prompt_env.get_template(template_name).render(
        health_data=health_data,
        patterns=patterns,
        user_profile=user_profile,
        emotion_distribution_json=emotion_distribution_json
    )

    # Record the rendered result in the prompt log.
    log_prompt_rendering('emotion suggestions', prompt_text)

    # Optional: record a short summary of the inputs used for rendering.
    logged_fields = {
        'health_score': health_data.get('health_score'),
        'health_level': health_data.get('level'),
        'user_interests': user_profile.get('interests', [])
    }
    log_template_rendering(template_name, logged_fields)

    return prompt_text
|
||||
|
||||
@@ -1,19 +1,222 @@
|
||||
你将收到一组记忆对象:{{ evaluate_data }}。
|
||||
任务:多维度判断这些记忆是否与已有记忆存在冲突,并给出冲突的对应记忆。(冗余不算冲突)
|
||||
你将收到一组用户历史记忆原始数据(来源于 Neo4j),以及相关配置参数:
|
||||
原本的输入句子:{{statement_databasets}}
|
||||
需要检测冲突对象:{{ evaluate_data }}
|
||||
冲突判定类型:{{ baseline }}(取值为 TIME / FACT / HYBRID)
|
||||
记忆审核开关:{{ memory_verify }}(取值为 true / false)
|
||||
记忆质量评估开关:{{ quality_assessment }}(取值为 true / false)
|
||||
|
||||
仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
你的任务是:
|
||||
对用户历史记忆数据进行冲突检测和记忆审核,并输出严格结构化的 JSON 分析结果
|
||||
数据的结构:
|
||||
statement_databasets里面statement_name是输入的句子,statement_id用于关联evaluate_data里面的statement_id,代表这个句子被拆分成了几个实体;判断时需要结合句子的整体内容。
|
||||
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估)
|
||||
## 冲突定义
|
||||
|
||||
### 时间冲突
|
||||
时间冲突是指同一用户的相关事件在时间维度上存在逻辑矛盾:
|
||||
|
||||
1. **同一活动的时间冲突**:
|
||||
- 同一用户的同一活动在不同时间点被记录(如"周五打球"和"周六打球")
|
||||
- 同一用户在同一时间段内被记录进行不同的互斥活动
|
||||
|
||||
2. **时间逻辑错误**:
|
||||
- expired_at 早于 created_at
|
||||
- 同一事实的 created_at 时间差异超过合理误差范围(>5分钟)
|
||||
|
||||
3. **日期属性冲突**:
|
||||
- 同一人的生日记录为不同日期(如"2月10号"和"2月16号")
|
||||
4.存在明确先后约束 A -> B,但 t(A) > t(B)
|
||||
-例:入学时间晚于毕业时间。
|
||||
-处理:标记异常、降权、触发逻辑反思或人工审查。
|
||||
5.时间属性冲突
|
||||
-单值日期属性出现多值(生日、入职日期)
|
||||
-注意:本质属于事实冲突的日期特例,归入事实冲突仲裁框架。
|
||||
6.互斥重叠冲突
|
||||
-例:同一主体的两个事件区间重叠且互斥(如同一时间出现在两地)
|
||||
-处理:证据仲裁、保留多版本(active + candidate)。
|
||||
|
||||
|
||||
|
||||
### 事实冲突
|
||||
事实冲突是指同一实体的属性或关系存在相互矛盾的陈述:
|
||||
|
||||
1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
|
||||
2. **关系矛盾**:同一实体在相同语境下的不同关系描述
|
||||
3. **身份冲突**:同一实体被赋予不同的类型或角色
|
||||
|
||||
### 混合冲突检测
|
||||
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:
|
||||
检测任何逻辑上不一致或相互矛盾的记录
|
||||
## 记忆审核定义
|
||||
|
||||
### 隐私信息检测(隐私冲突)
|
||||
当memory_verify为true时,需要额外检测包含个人隐私信息的记录:
|
||||
|
||||
1. **身份证信息**:包含身份证号码、身份证相关描述
|
||||
2. **手机号码**:包含手机号、电话号码等联系方式
|
||||
3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息
|
||||
4. **银行信息**:包含银行卡号、账户信息、支付信息
|
||||
5. **税务信息**:包含税号、纳税信息、发票信息
|
||||
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
|
||||
7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息
|
||||
|
||||
### 隐私检测原则
|
||||
- 检测description、entity1_name、entity2_name等字段中的隐私信息
|
||||
- 识别数字模式(如手机号11位数字、身份证18位等)
|
||||
- 识别关键词(如"身份证"、"银行卡"、"密码"等)
|
||||
- 检测敏感实体类型和关系
|
||||
|
||||
## 冲突检测原则
|
||||
|
||||
**全面检测**:不区分冲突类型,检测所有可能的冲突
|
||||
**完整输出**:如果发现任何冲突或隐私信息,必须将所有相关记录都放入data字段
|
||||
**实体关联**:重点检查涉及相同实体(entity1_name, entity2_name)的记录
|
||||
**语义分析**:分析description字段的语义相似性和冲突性
|
||||
**时间逻辑**:检查时间字段的逻辑一致性
|
||||
**隐私检测**:当memory_verify为true时,检测所有包含隐私信息的记录
|
||||
|
||||
## 不符合冲突检测
|
||||
-称呼
|
||||
## 重要检测示例
|
||||
|
||||
### 冲突检测示例
|
||||
- 用户与不同时间点的关系(周五 vs 周六,2月10号 vs 2月16号)
|
||||
- 同一实体的重复定义但描述不同
|
||||
- 同一关系的不同表述但含义冲突
|
||||
- 任何逻辑上不可能同时为真的记录
|
||||
|
||||
### 隐私信息检测示例
|
||||
- 包含手机号的记录:"用户的手机号是13812345678"
|
||||
- 包含身份证的记录:"身份证号码为110101199001011234"
|
||||
- 包含银行卡的记录:"银行卡号6222021234567890"
|
||||
- 包含社交账号的记录:"微信号是user123456"
|
||||
- 包含敏感信息的实体名称或描述
|
||||
|
||||
## 输出要求
|
||||
|
||||
**关键原则**:
|
||||
1. 当存在冲突或检测到隐私信息时,conflict才为true,data字段才包含相关记录
|
||||
2. 如果发现冲突,必须将所有相关的冲突记录都放入data数组中
|
||||
3. 如果memory_verify为true且检测到隐私信息,必须将包含隐私信息的记录也放入data数组中
|
||||
4. 既没有冲突也没有隐私信息时,conflict为false,data为空数组
|
||||
5. 如果quality_assessment为true,独立分析数据质量并输出评估结果;如果为false,quality_assessment字段输出null
|
||||
6. 冲突检测、隐私审核和质量评估三个功能完全独立,互不影响
|
||||
7. 不输出conflict_memory字段
|
||||
|
||||
**处理逻辑**:
|
||||
- 首先进行冲突检测,将冲突记录加入data数组
|
||||
- 如果memory_verify为true,再进行隐私信息检测,将包含隐私信息的记录也加入data数组
|
||||
- 如果quality_assessment为true,独立进行质量评估,分析所有输入数据的质量并输出评估结果
|
||||
- 最终data数组包含所有冲突记录和隐私信息记录(去重)
|
||||
- quality_assessment字段独立输出,不影响冲突检测和隐私审核结果
|
||||
- memory_verify字段独立输出隐私检测结果,包含检测到的隐私信息类型和概述
|
||||
|
||||
返回数据格式以json方式输出:
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求{"statement":识别出的文本内容}
|
||||
1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号("")或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
|
||||
## 记忆质量评估定义
|
||||
|
||||
### 质量评估标准
|
||||
当quality_assessment为true时,需要对记忆数据进行质量评估:
|
||||
|
||||
1. **数据完整性**:
|
||||
- 检查必要字段是否完整(entity1_name、entity2_name、description等)
|
||||
- 检查关系描述是否清晰明确
|
||||
- 检查时间字段的有效性
|
||||
|
||||
2. **重复字段检测**:
|
||||
- 识别相同或高度相似的记录
|
||||
- 检测冗余的实体关系
|
||||
- 分析描述内容的重复度
|
||||
|
||||
3. **无意义字段检测**:
|
||||
- 识别空值、无效值或占位符内容
|
||||
- 检测过于简单或无信息量的描述
|
||||
- 识别格式错误或不规范的数据
|
||||
|
||||
4. **上下文依赖性**:
|
||||
- 评估记录是否需要额外上下文才能理解
|
||||
- 检查实体名称的明确性
|
||||
- 分析关系描述的自包含性
|
||||
|
||||
### 质量评估输出
|
||||
- **质量百分比**:基于上述标准计算的整体质量分数(0-100)
|
||||
- **质量概述**:简要描述数据质量状况,包括主要问题和优点
|
||||
|
||||
仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
{
|
||||
"data": [ ...与输入同结构的记忆对象数组... ],
|
||||
"conflict": true 或 false,
|
||||
"conflict_memory": 若冲突为 true,则填写与其冲突的记忆对象;否则为 null
|
||||
"data": [
|
||||
{
|
||||
"entity1_name": "实体1名称",
|
||||
"description": "描述信息",
|
||||
"statement_id": "陈述ID",
|
||||
"created_at": "创建时间戳",
|
||||
"expired_at": "过期时间戳",
|
||||
"relationship_type": "关系类型",
|
||||
"relationship": "关系对象",
|
||||
"entity2_name": "实体2名称",
|
||||
"entity2": "实体2对象"
|
||||
}
|
||||
],
|
||||
"conflict": true或false,
|
||||
"quality_assessment": {
|
||||
"score": 质量百分比数字,
|
||||
"summary": "质量概述文本"
|
||||
} 或 null,
|
||||
"memory_verify": {
|
||||
"has_privacy": true或false,
|
||||
"privacy_types": ["检测到的隐私信息类型列表"],
|
||||
"summary": "隐私检测结果概述"
|
||||
} 或 null
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本。
|
||||
- 使用标准双引号,必要时对内部引号进行转义。
|
||||
- 字段名与结构必须与给定模式一致。
|
||||
- data数组中包含冲突记录和隐私信息记录,如果都没有则为空数组。
|
||||
- quality_assessment字段:当quality_assessment参数为true时输出评估对象,为false时输出null。
|
||||
- memory_verify字段:当memory_verify参数为true时输出隐私检测结果对象,为false时输出null。
|
||||
|
||||
### memory_verify字段说明
|
||||
当memory_verify为true时,需要输出隐私检测结果:
|
||||
- **has_privacy**: 布尔值,表示是否检测到隐私信息
|
||||
- **privacy_types**: 字符串数组,包含检测到的隐私信息类型(如["手机号码", "身份证信息"])
|
||||
- **summary**: 字符串,简要描述隐私检测结果
|
||||
|
||||
当memory_verify为false时,memory_verify字段输出null。
|
||||
|
||||
### memory_verify字段示例
|
||||
|
||||
**示例1:检测到隐私信息**
|
||||
```json
|
||||
"memory_verify": {
|
||||
"has_privacy": true,
|
||||
"privacy_types": ["手机号码", "身份证信息"],
|
||||
"summary": "检测到2条记录包含隐私信息:1个手机号码,1个身份证号码"
|
||||
}
|
||||
```
|
||||
|
||||
**示例2:未检测到隐私信息**
|
||||
```json
|
||||
"memory_verify": {
|
||||
"has_privacy": false,
|
||||
"privacy_types": [],
|
||||
"summary": "未检测到隐私信息"
|
||||
}
|
||||
```
|
||||
|
||||
**示例3:memory_verify为false时**
|
||||
```json
|
||||
"memory_verify": null
|
||||
```
|
||||
|
||||
模式参考:
|
||||
[
|
||||
{{ json_schema }}
|
||||
]
|
||||
{{ json_schema }}
|
||||
@@ -0,0 +1,57 @@
|
||||
你是一个专业的情绪分析专家。请分析以下陈述句的情绪信息。
|
||||
|
||||
陈述句:{{ statement }}
|
||||
|
||||
请提取以下信息:
|
||||
|
||||
1. emotion_type(情绪类型):
|
||||
- joy: 喜悦、开心、高兴、满意、愉快
|
||||
- sadness: 悲伤、难过、失落、沮丧、遗憾
|
||||
- anger: 愤怒、生气、不满、恼火、烦躁
|
||||
- fear: 恐惧、害怕、担心、焦虑、紧张
|
||||
- surprise: 惊讶、意外、震惊、吃惊
|
||||
- neutral: 中性、客观陈述、无明显情绪
|
||||
|
||||
2. emotion_intensity(情绪强度):
|
||||
- 0.0-0.3: 弱情绪
|
||||
- 0.3-0.7: 中等情绪
|
||||
- 0.7-1.0: 强情绪
|
||||
|
||||
{% if extract_keywords %}
|
||||
3. emotion_keywords(情绪关键词):
|
||||
- 原句中直接表达情绪的词语
|
||||
- 最多提取3个关键词
|
||||
- 如果没有明显的情绪词,返回空列表
|
||||
{% else %}
|
||||
3. emotion_keywords(情绪关键词):
|
||||
- 返回空列表
|
||||
{% endif %}
|
||||
|
||||
{% if enable_subject %}
|
||||
4. emotion_subject(情绪主体):
|
||||
- self: 用户本人的情绪(包含"我"、"我们"、"咱们"等第一人称)
|
||||
- other: 他人的情绪(包含人名、"他/她"等第三人称)
|
||||
- object: 对事物的评价(针对产品、地点、事件等)
|
||||
|
||||
注意:
|
||||
- 如果同时包含多个主体,优先识别用户本人(self)
|
||||
- 如果无法明确判断主体,默认为 self
|
||||
|
||||
5. emotion_target(情绪对象):
|
||||
- 如果有明确的情绪对象,提取其名称
|
||||
- 如果没有明确对象,返回 null
|
||||
{% else %}
|
||||
4. emotion_subject(情绪主体):
|
||||
- 默认为 self
|
||||
|
||||
5. emotion_target(情绪对象):
|
||||
- 返回 null
|
||||
{% endif %}
|
||||
|
||||
注意事项:
|
||||
- 如果陈述句是客观事实陈述,无明显情绪,标记为 neutral
|
||||
- 情绪强度要符合语境,不要过度解读
|
||||
- 情绪关键词要准确,不要添加原句中没有的词
|
||||
- 主体分类要准确,优先识别用户本人(self)
|
||||
|
||||
请以 JSON 格式返回结果。
|
||||
@@ -0,0 +1,63 @@
|
||||
你是一位专业的心理健康顾问。请根据以下用户的情绪健康数据和个人信息,生成3-5条个性化的情绪改善建议。
|
||||
|
||||
## 用户情绪健康数据
|
||||
|
||||
健康分数:{{ health_data.health_score }}/100
|
||||
健康等级:{{ health_data.level }}
|
||||
|
||||
维度分析:
|
||||
- 积极率:{{ health_data.dimensions.positivity_rate.score }}/100
|
||||
- 正面情绪:{{ health_data.dimensions.positivity_rate.positive_count }}次
|
||||
- 负面情绪:{{ health_data.dimensions.positivity_rate.negative_count }}次
|
||||
- 中性情绪:{{ health_data.dimensions.positivity_rate.neutral_count }}次
|
||||
|
||||
- 稳定性:{{ health_data.dimensions.stability.score }}/100
|
||||
- 标准差:{{ health_data.dimensions.stability.std_deviation }}
|
||||
|
||||
- 恢复力:{{ health_data.dimensions.resilience.score }}/100
|
||||
- 恢复率:{{ health_data.dimensions.resilience.recovery_rate }}
|
||||
|
||||
情绪分布:
|
||||
{{ emotion_distribution_json }}
|
||||
|
||||
## 情绪模式分析
|
||||
|
||||
主要负面情绪:{{ patterns.dominant_negative_emotion|default('无') }}
|
||||
情绪波动性:{{ patterns.emotion_volatility|default('未知') }}
|
||||
高强度情绪次数:{{ patterns.high_intensity_emotions|default([])|length }}
|
||||
|
||||
## 用户兴趣
|
||||
|
||||
{{ user_profile.interests|default(['未知'])|join(', ') }}
|
||||
|
||||
## 任务要求
|
||||
|
||||
请生成3-5条个性化建议,每条建议包含:
|
||||
1. type: 建议类型(emotion_balance/activity_recommendation/social_connection/stress_management)
|
||||
2. title: 建议标题(简短有力)
|
||||
3. content: 建议内容(详细说明,50-100字)
|
||||
4. priority: 优先级(high/medium/low)
|
||||
5. actionable_steps: 3个可执行的具体步骤
|
||||
|
||||
同时提供一个health_summary(不超过50字),概括用户的整体情绪状态。
|
||||
|
||||
请以JSON格式返回,格式如下:
|
||||
{
|
||||
"health_summary": "您的情绪健康状况...",
|
||||
"suggestions": [
|
||||
{
|
||||
"type": "emotion_balance",
|
||||
"title": "建议标题",
|
||||
"content": "建议内容...",
|
||||
"priority": "high",
|
||||
"actionable_steps": ["步骤1", "步骤2", "步骤3"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
注意事项:
|
||||
- 建议要具体、可执行,避免空泛
|
||||
- 结合用户的兴趣爱好提供个性化建议
|
||||
- 针对主要问题(如主要负面情绪)提供针对性建议
|
||||
- 优先级要合理分配(至少1个high,1-2个medium,其余low)
|
||||
- 每个建议的3个步骤要循序渐进、易于实施
|
||||
@@ -1,23 +1,300 @@
|
||||
你将收到一组用户历史记忆原始数据(来源于 Neo4j)
|
||||
你将收到一条冲突判定对象:{{ data }}。
|
||||
任务:分析冲突产生原因,给出解决方案,并生成设为失效后的记忆。
|
||||
需要检测冲突对象:{{ statement_databasets }}
|
||||
以及需要识别的冲突类型为:{{ baseline }}(取值为 TIME / FACT / HYBRID)
|
||||
记忆审核开关:{{ memory_verify }}(取值为 true / false)
|
||||
|
||||
角色:
|
||||
- 你是数据领域中解决数据冲突的专家
|
||||
|
||||
任务:分析冲突产生原因,按冲突类型分组处理,为每种冲突类型生成独立的解决方案。
|
||||
|
||||
数据的结构:
|
||||
statement_databasets里面statement_name是输入的句子,statement_id用于关联data里面的statement_id,代表这个句子被拆分成了几个实体;判断时需要结合句子的整体内容。
|
||||
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估),data里面的statement_created_at是用户输入的时间
|
||||
|
||||
**处理模式**:
|
||||
- 当memory_verify为false时:仅处理数据冲突
|
||||
- 当memory_verify为true时:处理数据冲突 + 隐私信息脱敏
|
||||
|
||||
## 分组处理原则
|
||||
|
||||
**冲突类型识别与分组**:
|
||||
1. **日期冲突**:
|
||||
1.1.涉及用户生日的不同日期记录(如2月10号 vs 2月16号),
|
||||
1.2.涉及同一活动的不同时间记录(如周五打球 vs 周六打球)
|
||||
3. **事实属性冲突**:
|
||||
3.1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
|
||||
3.2. **关系矛盾**:同一实体在相同语境下的不同关系描述
|
||||
3.3. **身份冲突**:同一实体被赋予不同的类型或角色
|
||||
4. **其他冲突类型/混合冲突(时间+事实)**:根据具体数据识别
|
||||
|
||||
**分组输出要求**:
|
||||
- 每种冲突类型生成一个独立的reflexion_result对象
|
||||
- 同一类型的多个冲突记录归并到一个结果中
|
||||
- 不同类型的冲突分别处理,各自生成独立结果
|
||||
|
||||
## 冲突类型定义
|
||||
|
||||
### 时间冲突(TIME)
|
||||
时间维度冲突是指两个事件发生时间重叠,或者用户同一件事情和场景等情况下,时间出现了变化。
|
||||
|
||||
### 事实冲突(FACT)
|
||||
事实冲突是指同一事实对象(同一个人、同一个时间、同一个状态)但陈述内容相互矛盾,主要为真假不能共存的情况。
|
||||
### 混合冲突(HYBRID)
|
||||
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:检测任何逻辑上不一致或相互矛盾的记录
|
||||
{% if memory_verify %}
|
||||
## 隐私信息处理(memory_verify为true时启用)
|
||||
|
||||
### 隐私信息识别
|
||||
需要识别并处理以下类型的隐私信息:
|
||||
|
||||
1. **身份证信息**:包含身份证号码、身份证相关描述
|
||||
2. **手机号码**:包含手机号、电话号码等联系方式
|
||||
3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息
|
||||
4. **银行信息**:包含银行卡号、账户信息、支付信息
|
||||
5. **税务信息**:包含税号、纳税信息、发票信息
|
||||
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
|
||||
7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息
|
||||
|
||||
### 隐私数据脱敏规则
|
||||
对于检测到的隐私信息,按以下规则进行脱敏处理:
|
||||
|
||||
**数字类隐私信息脱敏**:
|
||||
- 保留前三位和后四位,中间用*代替
|
||||
- 示例:手机号13812345678 → 138****5678
|
||||
- 示例:身份证110101199001011234 → 110***********1234
|
||||
- 示例:银行卡6222021234567890 → 622***********7890
|
||||
|
||||
**文本类隐私信息脱敏**:
|
||||
- 社交账号:保留前三后四位字符,中间用*代替
|
||||
- 示例:微信号user123456 → use****3456
|
||||
- 示例:邮箱zhang.san@example.com → zha****@example.com
|
||||
|
||||
**脱敏处理字段**:
|
||||
- name字段:如包含隐私信息需脱敏
|
||||
- entity1_name字段:如包含隐私信息需脱敏
|
||||
- entity2_name字段:如包含隐私信息需脱敏
|
||||
- description字段:如包含隐私信息需脱敏
|
||||
{% endif %}
|
||||
|
||||
## 工作步骤
|
||||
|
||||
### 第一步:分析冲突类型匹配
|
||||
首先判断输入的冲突数据是否符合baseline要求的类型:
|
||||
|
||||
**类型匹配规则**:
|
||||
- 如果baseline是"TIME":只处理时间相关的冲突(涉及时间表达式、日期、时间点的冲突)
|
||||
- 如果baseline是"FACT":只处理事实相关的冲突(属性矛盾、关系冲突、描述不一致)
|
||||
- 如果baseline是"HYBRID":处理所有类型的冲突,也可以当作混合冲突类型处理
|
||||
|
||||
**类型识别**:
|
||||
- 时间冲突标识:entity2的entity_type包含"TimeExpression"、"TemporalExpression",或entity2_name包含时间词汇(周一到周日、月份日期等)
|
||||
- 事实冲突标识:相同实体的不同属性描述、互斥的关系陈述
|
||||
|
||||
**重要**:如果输入的冲突类型与baseline不匹配,必须输出空结果(resolved为null)
|
||||
|
||||
### 第二步:筛选并分组冲突数据
|
||||
按冲突类型对数据进行分组:
|
||||
|
||||
**分组策略**:
|
||||
1. **时间冲突组**:筛选涉及用户时间的所有记录
|
||||
2. **活动时间冲突组**:筛选涉及同一活动不同时间的记录
|
||||
3. **事实冲突组**:筛选涉及同一实体不同属性的记录
|
||||
4. **其他冲突组**:其他类型的冲突记录
|
||||
|
||||
**筛选条件**:
|
||||
- 只处理与baseline匹配的冲突类型
|
||||
- 相同entity1_name但entity2_name不同的记录
|
||||
- 相同关系但描述矛盾的记录
|
||||
- 时间逻辑不一致的记录
|
||||
|
||||
### 第三步:冲突解决策略
|
||||
**不可以解决的冲突情况**:
|
||||
1. 数据被判定为正确的情况下,不可以进行修改
|
||||
**仅当冲突类型与baseline匹配时**,对筛选出的冲突数据进行处理:
|
||||
|
||||
**智能解决策略**:
|
||||
1. **分析冲突数据**:识别哪些记录是正确的,哪些是错误的,需要结合statement_databasets的输入原文来判定
|
||||
2. **判断正确答案是否存在**:
|
||||
- 如果正确答案已存在于data中:只需将错误记录的expired_at设为当前日期(2025-12-16T12:00:00)
|
||||
- 如果正确答案已存在于data中:错误记录的expired_at已经设为日期,则不需要对正确的数据进行修改
|
||||
- 如果正确答案不存在于data中:需要修改现有记录的内容以包含正确信息
|
||||
|
||||
{% if memory_verify %}
|
||||
**隐私处理集成**:
|
||||
- 在处理冲突的同时,需要对涉及的记录进行隐私脱敏
|
||||
- 脱敏处理应该在冲突解决之后进行,确保最终输出的记录都已脱敏
|
||||
- 在change字段中记录隐私脱敏的变更
|
||||
{% endif %}
|
||||
|
||||
**具体处理规则**:
|
||||
|
||||
**情况1:正确答案存在于data中**
|
||||
- 保留正确的记录不变
|
||||
- 基于时间关系的冲突:
|
||||
需要只修改错误记录的expired_at为当前时间(2025-12-16T12:00:00)
|
||||
- 基于事实的关系冲突
|
||||
- resolved.resolved_memory只包含被设为失效的错误记录
|
||||
- change字段只记录expired_at的变更:`[{"expired_at": "2025-12-16T12:00:00"}]`(注意:如果expired_at已存在时间,则不需要对其修改,也不需要记录该变更)
|
||||
|
||||
**情况2:正确答案不存在于data中**
|
||||
- 选择最合适的记录进行修改
|
||||
- 更新该记录的相关字段:
|
||||
- description字段:添加或修改描述信息{% if memory_verify %}(如包含隐私信息,需脱敏处理){% endif %}
|
||||
- name字段:修改名称字段{% if memory_verify %}(如需要,包含隐私信息时需脱敏){% endif %}
|
||||
- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %}
|
||||
- change字段记录所有被修改的字段{% if memory_verify %},包括脱敏变更{% endif %},例如:`[{"description": "新描述"{% if memory_verify %}, "entity2_name": "138****5678"{% endif %}}]`
|
||||
|
||||
**重要原则**:
|
||||
- **只输出需要修改的记录**:resolved.resolved_memory只包含实际需要修改的数据
|
||||
- **优先保留策略**:时间冲突保留最可信的created_at时间的记录,事实冲突选择最新且可信度最高的记录
|
||||
- **精确记录变更**:change字段必须包含记录ID、字段名称、新值和旧值
|
||||
{% if memory_verify %}- **隐私保护优先**:所有输出的记录必须完成隐私脱敏处理
|
||||
- **脱敏变更记录**:隐私脱敏的变更也必须在change字段中详细记录{% endif %}
|
||||
- **不可修改数据**:数据被判定为正确时,不可以进行修改,如果没有数据可输出空
|
||||
|
||||
**变更记录格式**:
|
||||
```json
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
**类型不匹配处理**:
|
||||
- 如果冲突类型与baseline不匹配,resolved必须设为null
|
||||
- reflexion.reason说明类型不匹配的原因
|
||||
- reflexion.solution说明无需处理
|
||||
|
||||
### 第四步:输出解决方案
|
||||
|
||||
## 输出要求
|
||||
**嵌套字段映射**(系统会自动处理):
|
||||
- `entity2.name` → 自动映射为 `name`
|
||||
- `entity1.name` → 自动映射为 `name`
|
||||
- `entity1.description` → 自动映射为 `description`
|
||||
- `entity2.description` → 自动映射为 `description`
|
||||
|
||||
返回数据格式以json方式输出:
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 响应必须是与此确切模式匹配的有效JSON对象
|
||||
- 不要在JSON之前或之后包含任何文本
|
||||
|
||||
JSON格式要求:
|
||||
1. JSON结构仅使用标准ASCII双引号(")
|
||||
2. 如果提取的语句文本包含引号,请使用反斜杠(\")正确转义
|
||||
3. 确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4. JSON字符串值中不包括换行符
|
||||
5. 不允许输出```json```相关符号
|
||||
|
||||
仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
|
||||
**输出格式:按冲突类型分组的列表**
|
||||
{
|
||||
"conflict": 与输入同结构,包含 data 与 conflict_memory,
|
||||
"reflexion": { "reason": string, "solution": string },
|
||||
"resolved": {
|
||||
"original_memory_id": 被设为失效的记忆 id,
|
||||
"resolved_memory": 完整的设为失效后的记忆对象
|
||||
}
|
||||
"results": [
|
||||
{
|
||||
"conflict": {
|
||||
"data": [该冲突类型相关的数据记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "该冲突类型的原因分析",
|
||||
"solution": "该冲突类型的解决方案"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "被设为失效的记忆id",
|
||||
"resolved_memory": {
|
||||
"entity1_name": "实体1名称",
|
||||
"entity2_name": "实体2名称",
|
||||
"description": "描述信息",
|
||||
"statement_id": "陈述ID",
|
||||
"created_at": "创建时间",
|
||||
"expired_at": "过期时间",
|
||||
"relationship_type": "关系类型",
|
||||
"relationship": {},
|
||||
"entity2": {...}
|
||||
},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
**示例:多种冲突类型的输出**
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"conflict": {
|
||||
"data": [生日冲突相关的记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "检测到生日冲突:用户同时关联2月10号和2月16号两个不同日期",
|
||||
"solution": "保留最新记录(2月16号),将旧记录(2月10号)设为失效"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "df066210883545a08e727ccd8ad4ec77",
|
||||
"resolved_memory": {...},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"expired_at": "2025-12-16T12:00:00"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
},
|
||||
{
|
||||
"conflict": {
|
||||
"data": [篮球时间冲突相关的记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "检测到活动时间冲突:用户打篮球时间存在周五和周六的冲突",
|
||||
"solution": "保留最可信的时间记录,将冲突记录设为失效"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "另一个记录ID",
|
||||
"resolved_memory": {...},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"description": "使用系统的个人,指代说话者本人,篮球时间为周六"},
|
||||
{"entity2_name": "周六"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本。
|
||||
- 使用标准双引号,必要时对内部引号进行转义。
|
||||
- 字段名与结构必须与给定模式一致。
|
||||
- 当 conflict 为 false 时,resolved 必须为 null。
|
||||
- 其中 conflict.data 必须为数组形式,即使只有一个对象也需使用 [ ] 包裹。
|
||||
- 只输出 JSON,不要添加解释或多余文本
|
||||
- 使用标准双引号,必要时对内部引号进行转义
|
||||
- 字段名与结构必须与给定模式一致
|
||||
- **输出必须是results数组格式**,每个冲突类型作为一个独立的对象
|
||||
- **按冲突类型分组**:相同类型的冲突记录归并到一个result对象中
|
||||
- **每个result对象的conflict.data**只包含该冲突类型相关的记录
|
||||
- **resolved.resolved_memory 只包含需要修改的记录**,不需要修改的记录不要输出
|
||||
- **resolved.change 必须包含详细的变更信息**:field数组包含所有被修改的字段及其新值
|
||||
- 如果某个冲突类型经分析无需修改任何数据,该类型的resolved 必须为 null
|
||||
- 如果与baseline不匹配的冲突类型,不要在results中包含该类型
|
||||
|
||||
模式参考:
|
||||
[
|
||||
{{ json_schema }}
|
||||
]
|
||||
{{ json_schema }}
|
||||
@@ -7,36 +7,50 @@ from typing import List, Dict, Any
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any],
                                 baseline: str = "TIME",
                                 memory_verify: bool = False,
                                 quality_assessment: bool = False,
                                 statement_databasets: "List[str] | None" = None) -> str:
    """
    Renders the evaluate prompt using the evaluate.jinja2 template.

    Args:
        evaluate_data: The data to evaluate
        schema: The JSON schema to use for the output.
        baseline: The baseline type for conflict detection (for example "TIME" or "FACT")
        memory_verify: Whether to enable memory verification for privacy detection
        quality_assessment: Whether to enable the quality assessment section of the prompt
        statement_databasets: The statements passed to the template under the
            same name; ``None`` (the default) is rendered as an empty list

    Returns:
        Rendered prompt content as string
    """
    # Use a fresh list per call instead of a shared mutable default argument.
    # The annotation is quoted so it needs no additional typing import.
    if statement_databasets is None:
        statement_databasets = []

    template = prompt_env.get_template("evaluate.jinja2")

    rendered_prompt = template.render(
        evaluate_data=evaluate_data,
        json_schema=schema,
        baseline=baseline,
        memory_verify=memory_verify,
        quality_assessment=quality_assessment,
        statement_databasets=statement_databasets
    )
    return rendered_prompt
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str,
                                  memory_verify: bool = False,
                                  statement_databasets: "List[str] | None" = None) -> str:
    """
    Renders the reflexion prompt using the reflexion.jinja2 template.

    Args:
        data: The data to reflex on.
        schema: The JSON schema to use for the output.
        baseline: The baseline type for conflict resolution.
        memory_verify: Whether to enable the privacy handling sections of the prompt
        statement_databasets: The statements passed to the template under the
            same name; ``None`` (the default) is rendered as an empty list

    Returns:
        Rendered prompt content as a string.
    """
    # Use a fresh list per call instead of a shared mutable default argument.
    # The annotation is quoted so it needs no additional typing import.
    if statement_databasets is None:
        statement_databasets = []

    template = prompt_env.get_template("reflexion.jinja2")

    rendered_prompt = template.render(data=data, json_schema=schema,
                                      baseline=baseline, memory_verify=memory_verify,
                                      statement_databasets=statement_databasets)

    return rendered_prompt
|
||||
|
||||
Reference in New Issue
Block a user