style(memory): Some code style optimizations
This commit is contained in:
@@ -44,21 +44,21 @@ def parse_historical_datetime(v):
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
|
||||
# 处理 Neo4j DateTime 对象
|
||||
if hasattr(v, 'to_native'):
|
||||
return v.to_native()
|
||||
|
||||
|
||||
# 处理 Python datetime 对象
|
||||
if isinstance(v, datetime):
|
||||
return v
|
||||
|
||||
|
||||
if isinstance(v, str):
|
||||
# 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM]
|
||||
# 支持1-4位年份
|
||||
pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?'
|
||||
match = re.match(pattern, v)
|
||||
|
||||
|
||||
if match:
|
||||
try:
|
||||
year = int(match.group(1))
|
||||
@@ -68,31 +68,31 @@ def parse_historical_datetime(v):
|
||||
minute = int(match.group(5)) if match.group(5) else 0
|
||||
second = int(match.group(6)) if match.group(6) else 0
|
||||
microsecond = 0
|
||||
|
||||
|
||||
# 处理微秒
|
||||
if match.group(7):
|
||||
# 补齐或截断到6位
|
||||
us_str = match.group(7).ljust(6, '0')[:6]
|
||||
microsecond = int(us_str)
|
||||
|
||||
|
||||
# 处理时区
|
||||
tzinfo = None
|
||||
if 'Z' in v or match.group(8):
|
||||
tzinfo = timezone.utc
|
||||
|
||||
|
||||
# 创建 datetime 对象
|
||||
return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo)
|
||||
|
||||
|
||||
except (ValueError, OverflowError):
|
||||
# 日期值无效(如月份13、日期32等)
|
||||
return None
|
||||
|
||||
|
||||
# 如果不匹配模式,尝试使用 fromisoformat(用于标准格式)
|
||||
try:
|
||||
return datetime.fromisoformat(v.replace('Z', '+00:00'))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
return v
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ class EntityEntityEdge(Edge):
|
||||
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
@@ -206,7 +206,8 @@ class DialogueNode(Node):
|
||||
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
||||
content: str = Field(..., description="Dialogue content")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this dialogue (integer or string)")
|
||||
|
||||
|
||||
class StatementNode(Node):
|
||||
@@ -241,17 +242,17 @@ class StatementNode(Node):
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk")
|
||||
stmt_type: str = Field(..., description="Type of the statement")
|
||||
statement: str = Field(..., description="The statement text content")
|
||||
|
||||
|
||||
# Speaker identification
|
||||
speaker: Optional[str] = Field(
|
||||
None,
|
||||
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
|
||||
)
|
||||
|
||||
|
||||
# Emotion fields (ordered as requested, emotion_intensity first for display)
|
||||
emotion_intensity: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Emotion intensity: 0.0-1.0 (displayed on node)"
|
||||
)
|
||||
@@ -264,25 +265,26 @@ class StatementNode(Node):
|
||||
description="Emotion subject: self/other/object"
|
||||
)
|
||||
emotion_type: Optional[str] = Field(
|
||||
None,
|
||||
None,
|
||||
description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
|
||||
)
|
||||
emotion_keywords: Optional[List[str]] = Field(
|
||||
default_factory=list,
|
||||
description="Emotion keywords list, max 3 items"
|
||||
)
|
||||
|
||||
|
||||
# Temporal fields
|
||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
|
||||
# Embedding and other fields
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -309,13 +311,13 @@ class StatementNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
"""使用通用的历史日期解析函数"""
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
@field_validator('emotion_type', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_type(cls, v):
|
||||
@@ -326,7 +328,7 @@ class StatementNode(Node):
|
||||
if v not in valid_types:
|
||||
raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
|
||||
return v
|
||||
|
||||
|
||||
@field_validator('emotion_subject', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_subject(cls, v):
|
||||
@@ -337,7 +339,7 @@ class StatementNode(Node):
|
||||
if v not in valid_subjects:
|
||||
raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
|
||||
return v
|
||||
|
||||
|
||||
@field_validator('emotion_keywords', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_keywords(cls, v):
|
||||
@@ -405,19 +407,20 @@ class ExtractedEntityNode(Node):
|
||||
entity_type: str = Field(..., description="Type of the entity")
|
||||
description: str = Field(..., description="Entity description")
|
||||
example: str = Field(
|
||||
default="",
|
||||
default="",
|
||||
description="A concise example (around 20 characters) to help understand the entity"
|
||||
)
|
||||
aliases: List[str] = Field(
|
||||
default_factory=list,
|
||||
default_factory=list,
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -444,16 +447,16 @@ class ExtractedEntityNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
|
||||
# Explicit Memory Classification
|
||||
is_explicit_memory: bool = Field(
|
||||
default=False,
|
||||
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
|
||||
)
|
||||
|
||||
|
||||
@field_validator('aliases', mode='before')
|
||||
@classmethod
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
"""Validate and clean aliases field using utility function.
|
||||
|
||||
This validator ensures that the aliases field is always a valid list of strings.
|
||||
@@ -507,8 +510,9 @@ class MemorySummaryNode(Node):
|
||||
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
# ACT-R Forgetting Engine Properties
|
||||
original_statement_id: Optional[str] = Field(
|
||||
None,
|
||||
@@ -522,7 +526,7 @@ class MemorySummaryNode(Node):
|
||||
None,
|
||||
description="Timestamp when the nodes were merged"
|
||||
)
|
||||
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
|
||||
@@ -227,7 +227,8 @@ class EmbeddingGenerator:
|
||||
|
||||
# 打印前几个嵌入向量的维度
|
||||
for i in range(min(5, len(embeddings))):
|
||||
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
print(f"实体 '{entity_texts[i]}' "
|
||||
f"嵌入向量维度: {len(embeddings[i])}")
|
||||
|
||||
# 将嵌入向量赋值给实体
|
||||
for ent, emb in zip(entity_refs, embeddings):
|
||||
|
||||
Reference in New Issue
Block a user