style(memory): Some code style optimizations

This commit is contained in:
Eternity
2026-03-20 18:22:20 +08:00
parent e8ae46b286
commit c17a2dad2d
8 changed files with 296 additions and 292 deletions

View File

@@ -75,11 +75,6 @@ async def get_storage_info(
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 @router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
def create_config( def create_config(
payload: ConfigParamsCreate, payload: ConfigParamsCreate,
@@ -107,9 +102,11 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type) lang = get_language_from_header(x_language_type)
if lang == "en": if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else: else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg) return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}") api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
@@ -119,9 +116,11 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type) lang = get_language_from_header(x_language_type)
if lang == "en": if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else: else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg) return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}") api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@@ -217,7 +216,8 @@ def update_config(
# 校验至少有一个字段需要更新 # 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
"config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try: try:
@@ -278,6 +278,7 @@ def read_config_extracted(
api_logger.error(f"Read config extracted failed: {str(e)}") api_logger.error(f"Read config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 @router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
def read_all_config( def read_all_config(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
@@ -443,8 +444,6 @@ async def search_entity_edges(
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse) @router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
async def get_hot_memory_tags_api( async def get_hot_memory_tags_api(
limit: int = 10, limit: int = 10,
@@ -553,4 +552,3 @@ async def get_recent_activity_stats_api(
except Exception as e: except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}") api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))

View File

@@ -598,8 +598,10 @@ class LangChainAgent:
for msg in reversed(output_messages): for msg in reversed(output_messages):
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", total_tokens = response_meta.get("token_usage", {}).get(
0) if response_meta else 0 "total_tokens",
0
) if response_meta else 0
yield total_tokens yield total_tokens
break break
if memory_flag: if memory_flag:

View File

@@ -206,7 +206,8 @@ class DialogueNode(Node):
ref_id: str = Field(..., description="Reference identifier of the dialog") ref_id: str = Field(..., description="Reference identifier of the dialog")
content: str = Field(..., description="Dialogue content") content: str = Field(..., description="Dialogue content")
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector") dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this dialogue (integer or string)")
class StatementNode(Node): class StatementNode(Node):
@@ -281,7 +282,8 @@ class StatementNode(Node):
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this statement (integer or string)")
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
@@ -416,7 +418,8 @@ class ExtractedEntityNode(Node):
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity") # fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this entity (integer or string)")
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
@@ -507,7 +510,8 @@ class MemorySummaryNode(Node):
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory") memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this summary (integer or string)")
# ACT-R Forgetting Engine Properties # ACT-R Forgetting Engine Properties
original_statement_id: Optional[str] = Field( original_statement_id: Optional[str] = Field(

View File

@@ -227,7 +227,8 @@ class EmbeddingGenerator:
# 打印前几个嵌入向量的维度 # 打印前几个嵌入向量的维度
for i in range(min(5, len(embeddings))): for i in range(min(5, len(embeddings))):
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") print(f"实体 '{entity_texts[i]}' "
f"嵌入向量维度: {len(embeddings[i])}")
# 将嵌入向量赋值给实体 # 将嵌入向量赋值给实体
for ent, emb in zip(entity_refs, embeddings): for ent, emb in zip(entity_refs, embeddings):

View File

@@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id,
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
# Entity Merge Query # Entity Merge Query
MERGE_ENTITIES = """ MERGE_ENTITIES = """
MATCH (canonical:ExtractedEntity {id: $canonical_id}) MATCH (canonical:ExtractedEntity {id: $canonical_id})
@@ -829,7 +828,6 @@ neo4j_query_all = """
other as entity2 other as entity2
""" """
'''针对当前节点下扩长的句子,实体和总结''' '''针对当前节点下扩长的句子,实体和总结'''
Memory_Timeline_ExtractedEntity = """ Memory_Timeline_ExtractedEntity = """
MATCH (n)-[r1]-(e)-[r2]-(ms) MATCH (n)-[r1]-(e)-[r2]-(ms)
@@ -1060,7 +1058,6 @@ Graph_Node_query = """
""" """
# ============================================================ # ============================================================
# Community 节点 & BELONGS_TO_COMMUNITY 边 # Community 节点 & BELONGS_TO_COMMUNITY 边
# ============================================================ # ============================================================

View File

@@ -8,9 +8,6 @@ import uuid
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
# ============================================================================ # ============================================================================
# 从 json_schema.py 迁移的 Schema # 从 json_schema.py 迁移的 Schema
# ============================================================================ # ============================================================================
@@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel):
class ConflictResultSchema(BaseModel): class ConflictResultSchema(BaseModel):
"""Schema for the conflict result data in the reflexion_data.json file.""" """Schema for the conflict result data in the reflexion_data.json file."""
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.") data: List[BaseDataSchema] = Field(...,
description="The conflict memory data. Only contains conflicting records when conflict is True.")
conflict: bool = Field(..., description="Whether the memory is in conflict.") conflict: bool = Field(..., description="Whether the memory is in conflict.")
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") quality_assessment: Optional[QualityAssessmentSchema] = Field(None,
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
memory_verify: Optional[MemoryVerifySchema] = Field(None,
description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
@model_validator(mode="before") @model_validator(mode="before")
def _normalize_data(cls, v): def _normalize_data(cls, v):
@@ -105,12 +105,15 @@ class ChangeRecordSchema(BaseModel):
description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
) )
class ResolvedSchema(BaseModel): class ResolvedSchema(BaseModel):
"""Schema for the resolved memory data in the reflexion_data""" """Schema for the resolved memory data in the reflexion_data"""
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).") # resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None,
change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.") description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
change: Optional[List[ChangeRecordSchema]] = Field(None,
description="List of detailed change records with IDs and field information.")
class SingleReflexionResultSchema(BaseModel): class SingleReflexionResultSchema(BaseModel):
@@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel):
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.") resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
type: str = Field("reflexion_result", description="The type identifier.") type: str = Field("reflexion_result", description="The type identifier.")
class ReflexionResultSchema(BaseModel): class ReflexionResultSchema(BaseModel):
"""Schema for the complete reflexion result data - a list of individual conflict resolutions.""" """Schema for the complete reflexion result data - a list of individual conflict resolutions."""
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") results: List[SingleReflexionResultSchema] = Field(...,
description="List of individual conflict resolution results, grouped by conflict type.")
@model_validator(mode="before") @model_validator(mode="before")
def _normalize_resolved(cls, v): def _normalize_resolved(cls, v):
@@ -148,8 +153,8 @@ class ReflexionResultSchema(BaseModel):
class ConfigKey(BaseModel): # 配置参数键模型 class ConfigKey(BaseModel): # 配置参数键模型
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
user_id: str = Field("user_id", description="用户标识(字符串)") user_id: str | None = Field(default=None, description="用户标识(字符串)")
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)")
# Allowed chunking strategies (extendable later) # Allowed chunking strategies (extendable later)
@@ -241,6 +246,8 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致") reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致")
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致")
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
# config_name: str = Field("配置名称", description="配置名称(字符串)") # config_name: str = Field("配置名称", description="配置名称(字符串)")
@@ -387,6 +394,7 @@ def fail(
time=time or _now_ms(), time=time or _now_ms(),
) )
class GenerateCacheRequest(BaseModel): class GenerateCacheRequest(BaseModel):
"""缓存生成请求模型""" """缓存生成请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -472,7 +480,8 @@ class ForgettingStatsResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标") activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
node_distribution: Dict[str, int] = Field(..., description="节点类型分布") node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据每天取最后一次执行") recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
description="最近7个日期的遗忘趋势数据每天取最后一次执行")
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点") pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点")
timestamp: int = Field(..., description="统计时间(时间戳)") timestamp: int = Field(..., description="统计时间(时间戳)")

View File

@@ -11,9 +11,11 @@ import time
from datetime import datetime from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from app.core.logging_config import get_config_logger, get_logger from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.analytics.hot_memory_tags import ( from app.core.memory.analytics.hot_memory_tags import (
get_hot_memory_tags,
get_raw_tags_from_db, get_raw_tags_from_db,
filter_tags_with_llm, filter_tags_with_llm,
) )
@@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import (
) )
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.utils.sse_utils import format_sse_message from app.utils.sse_utils import format_sse_message
from dotenv import load_dotenv
from sqlalchemy.orm import Session
logger = get_logger(__name__) logger = get_logger(__name__)
config_logger = get_config_logger() config_logger = get_config_logger()
@@ -241,7 +241,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
except (ValueError, TypeError): except (ValueError, TypeError):
config_id_old = None config_id_old = None
if config_id_old: if config_id_old:
memory_config = config_id_old memory_config = config_id_old
else: else:
@@ -289,7 +288,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式 # 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
return self._convert_timestamps_to_format(data_list) return self._convert_timestamps_to_format(data_list)
async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]: async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]:
""" """
流式执行试运行,产生 SSE 格式的进度事件 流式执行试运行,产生 SSE 格式的进度事件
@@ -344,11 +342,13 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 关联了本体场景,优先使用 custom_text # 关联了本体场景,优先使用 custom_text
if hasattr(payload, 'custom_text') and payload.custom_text: if hasattr(payload, 'custom_text') and payload.custom_text:
dialogue_text = payload.custom_text.strip() dialogue_text = payload.custom_text.strip()
logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") logger.info(
f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
else: else:
# 如果没有提供 custom_text回退到 dialogue_text # 如果没有提供 custom_text回退到 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") logger.info(
f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
else: else:
# 没有关联本体场景,使用 dialogue_text # 没有关联本体场景,使用 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
@@ -360,7 +360,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}") logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
# 步骤 2: 创建进度回调函数捕获管线进度 # 步骤 2: 创建进度回调函数捕获管线进度
# 使用队列在回调和生成器之间传递进度事件 # 使用队列在回调和生成器之间传递进度事件
progress_queue: asyncio.Queue = asyncio.Queue() progress_queue: asyncio.Queue = asyncio.Queue()
@@ -382,7 +381,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
try: try:
from app.services.pilot_run_service import run_pilot_extraction from app.services.pilot_run_service import run_pilot_extraction
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") logger.info(
f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
await run_pilot_extraction( await run_pilot_extraction(
memory_config=memory_config, memory_config=memory_config,
dialogue_text=dialogue_text, dialogue_text=dialogue_text,
@@ -483,7 +483,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"time": int(time.time() * 1000) "time": int(time.time() * 1000)
}) })
async def _compute_ontology_coverage( async def _compute_ontology_coverage(
self, self,
extracted_result: Dict[str, Any], extracted_result: Dict[str, Any],
@@ -580,8 +579,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) -------------------- # -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
# Ensure env for connector (e.g., NEO4J_PASSWORD) # Ensure env for connector (e.g., NEO4J_PASSWORD)
load_dotenv()
_neo4j_connector = Neo4jConnector()
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
@@ -701,6 +698,7 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
) )
return result return result
async def analytics_hot_memory_tags( async def analytics_hot_memory_tags(
db: Session, db: Session,
current_user: User, current_user: User,
@@ -845,5 +843,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source} data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
return data return data

View File

@@ -1073,9 +1073,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
@celery_app.task(name="app.core.memory.agent.write_message", bind=True) @celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, def write_message_task(
self,
end_user_id: str,
message: list[dict],
config_id: str | int,
storage_type: str,
user_rag_memory_id: str, user_rag_memory_id: str,
language: str = "zh") -> Dict[str, Any]: language: str = "zh"
) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService. """Celery task to process a write message via MemoryAgentService.
Args: Args:
end_user_id: Group ID for the memory agent (also used as end_user_id) end_user_id: Group ID for the memory agent (also used as end_user_id)
@@ -1105,14 +1111,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
try: try:
with get_db_context() as db: with get_db_context() as db:
actual_config_id = resolve_config_id(config_id, db) actual_config_id = resolve_config_id(config_id, db)
print(100 * '-') logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} "
print(actual_config_id) f"(type: {type(actual_config_id).__name__})")
print(100 * '-')
logger.info(
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
except (ValueError, AttributeError) as e: except (ValueError, AttributeError) as e:
logger.error( logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} "
f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") f"(type: {type(config_id).__name__}), error: {e}")
return { return {
"status": "FAILURE", "status": "FAILURE",
"error": f"Invalid config_id format: {config_id} - {str(e)}", "error": f"Invalid config_id format: {config_id} - {str(e)}",
@@ -1151,8 +1154,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info( logger.info(f"[CELERY WRITE] Task completed successfully "
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
try: try:
@@ -1167,7 +1170,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
) )
except Exception as _e: except Exception as _e:
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}") logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
return { return {
"status": "SUCCESS", "status": "SUCCESS",
"result": result, "result": result,
@@ -2749,7 +2751,8 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
llm_model_id=llm_model_id, llm_model_id=llm_model_id,
) )
logger.info(f"[CommunityCluster] 用户 {end_user_id}{len(entities)} 个实体开始全量聚类llm_model_id={llm_model_id}") logger.info(
f"[CommunityCluster] 用户 {end_user_id}{len(entities)} 个实体开始全量聚类llm_model_id={llm_model_id}")
await engine.full_clustering(end_user_id) await engine.full_clustering(end_user_id)
initialized += 1 initialized += 1
logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成") logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")
@@ -2772,12 +2775,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
} }
try: try:
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
loop = set_asyncio_event_loop() loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time result["elapsed_time"] = time.time() - start_time