style(memory): Code style cleanups (import ordering, whitespace, line wrapping)

This commit is contained in:
Eternity
2026-03-20 18:22:20 +08:00
parent e8ae46b286
commit c17a2dad2d
8 changed files with 296 additions and 292 deletions

View File

@@ -11,9 +11,11 @@ import time
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.analytics.hot_memory_tags import (
get_hot_memory_tags,
get_raw_tags_from_db,
filter_tags_with_llm,
)
@@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import (
)
from app.services.memory_config_service import MemoryConfigService
from app.utils.sse_utils import format_sse_message
from dotenv import load_dotenv
from sqlalchemy.orm import Session
logger = get_logger(__name__)
config_logger = get_config_logger()
@@ -45,10 +45,10 @@ _neo4j_connector = Neo4jConnector()
class MemoryStorageService:
"""Service for memory storage operations"""
def __init__(self):
# Construction only emits an informational log line; no instance state is set here.
logger.info("MemoryStorageService initialized")
async def get_storage_info(self) -> dict:
"""
Example wrapper method - retrieves storage information
@@ -59,17 +59,17 @@ class MemoryStorageService:
Storage information dictionary
"""
logger.info("Getting storage info ")
# Empty wrapper - implement your logic here
result = {
"status": "active",
"message": "This is an example wrapper"
}
return result
class DataConfigService: # 数据配置服务类PostgreSQL
return result
class DataConfigService: # 数据配置服务类PostgreSQL
"""Service layer for config params CRUD.
使用 SQLAlchemy ORM 进行数据库操作。
@@ -114,7 +114,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
return data_list
# --- Create ---
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
# 业务层检查同一工作空间下是否已存在同名配置
if params.workspace_id and params.config_name:
from app.models.memory_config_model import MemoryConfig
@@ -183,20 +183,20 @@ class DataConfigService: # 数据配置服务类PostgreSQL
return None
# --- Delete ---
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数按配置ID
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]:
    """Delete the configuration identified by ``key.config_id``.

    Args:
        key: Request object carrying the ``config_id`` of the configuration
            to remove.

    Returns:
        ``{"affected": 1}`` when the repository reports a successful delete.

    Raises:
        ValueError: If the repository returns a falsy result, i.e. the
            configuration was not found.
    """
    deleted = MemoryConfigRepository.delete(self.db, key.config_id)
    if deleted:
        return {"affected": 1}
    # Repository signalled that nothing was deleted.
    raise ValueError("未找到配置")
# --- Update ---
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
def update(self, update: ConfigUpdate) -> Dict[str, Any]:
    """Apply a partial update to an existing configuration.

    Args:
        update: The update payload, forwarded unchanged to
            ``MemoryConfigRepository.update``.

    Returns:
        ``{"affected": 1}`` when the repository returns an updated config.

    Raises:
        ValueError: If the repository returns a falsy result, i.e. the
            configuration was not found.
    """
    updated_config = MemoryConfigRepository.update(self.db, update)
    if updated_config:
        return {"affected": 1}
    # Nothing came back from the repository, so there was no config to update.
    raise ValueError("未找到配置")
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
config = MemoryConfigRepository.update_extracted(self.db, update)
if not config:
raise ValueError("未找到配置")
@@ -207,14 +207,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config()
# --- Read ---
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]:
    """Fetch the extraction configuration for ``key.config_id``.

    Args:
        key: Request object carrying the ``config_id`` to look up.

    Returns:
        Whatever ``MemoryConfigRepository.get_extracted_config`` returns for
        that id, passed through unchanged.

    Raises:
        ValueError: If the repository returns a falsy result, i.e. the
            configuration was not found.
    """
    extracted = MemoryConfigRepository.get_extracted_config(self.db, key.config_id)
    if extracted:
        return extracted
    # A falsy lookup result is treated as "configuration not found".
    raise ValueError("未找到配置")
# --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数
results = MemoryConfigRepository.get_all(self.db, workspace_id)
# 检查并修正 pruning_scene 与 scene_name 不一致的记录
@@ -241,11 +241,10 @@ class DataConfigService: # 数据配置服务类PostgreSQL
except (ValueError, TypeError):
config_id_old = None
if config_id_old:
memory_config=config_id_old
memory_config = config_id_old
else:
memory_config=config.config_id
memory_config = config.config_id
config_dict = {
"config_id": memory_config,
"config_name": config.config_name,
@@ -289,7 +288,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
return self._convert_timestamps_to_format(data_list)
async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]:
"""
流式执行试运行,产生 SSE 格式的进度事件
@@ -311,14 +309,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"""
from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2])
try:
# 发出初始进度事件
yield format_sse_message("starting", {
"message": "开始试运行...",
"time": int(time.time() * 1000)
})
# 步骤 1: 配置加载和验证(数据库优先)
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
cid: Optional[str] = payload_cid if payload_cid else None
@@ -344,27 +342,28 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 关联了本体场景,优先使用 custom_text
if hasattr(payload, 'custom_text') and payload.custom_text:
dialogue_text = payload.custom_text.strip()
logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
logger.info(
f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
else:
# 如果没有提供 custom_text回退到 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
logger.info(
f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
else:
# 没有关联本体场景,使用 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}")
# 验证最终使用的文本不为空
if not dialogue_text:
raise ValueError("试运行模式必须提供有效的文本内容dialogue_text 或 custom_text")
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
# 步骤 2: 创建进度回调函数捕获管线进度
# 使用队列在回调和生成器之间传递进度事件
progress_queue: asyncio.Queue = asyncio.Queue()
async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
"""
进度回调函数,将进度事件放入队列
@@ -375,14 +374,15 @@ class DataConfigService: # 数据配置服务类PostgreSQL
data: 可选的结果数据(用于传递节点执行结果)
"""
await progress_queue.put((stage, message, data))
# 步骤 3: 在后台任务中执行管线
async def run_pipeline():
"""在后台执行管线并捕获异常"""
try:
from app.services.pilot_run_service import run_pilot_extraction
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
logger.info(
f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
await run_pilot_extraction(
memory_config=memory_config,
dialogue_text=dialogue_text,
@@ -391,60 +391,60 @@ class DataConfigService: # 数据配置服务类PostgreSQL
language=language,
)
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
# 标记管线完成
await progress_queue.put(("__PIPELINE_COMPLETE__", "", None))
except Exception as e:
# 将异常放入队列
await progress_queue.put(("__PIPELINE_ERROR__", str(e), None))
# 启动后台任务
pipeline_task = asyncio.create_task(run_pipeline())
# 步骤 4: 从队列中读取进度事件并发出
while True:
try:
# 等待进度事件,设置超时以检测客户端断开
stage, message, data = await asyncio.wait_for(
progress_queue.get(),
progress_queue.get(),
timeout=0.5
)
# 检查特殊标记
if stage == "__PIPELINE_COMPLETE__":
break
elif stage == "__PIPELINE_ERROR__":
raise RuntimeError(message)
# 构建进度事件数据
progress_data = {
"message": message,
"time": int(time.time() * 1000)
}
# 如果有结果数据,添加到事件中
if data:
progress_data["data"] = data
# 发出进度事件,使用 stage 作为事件类型
yield format_sse_message(stage, progress_data)
except TimeoutError:
# 超时,继续等待(这允许检测客户端断开)
continue
# 等待管线任务完成
await pipeline_task
# 步骤 5: 读取提取结果
from app.core.config import settings
result_path = settings.get_memory_output_path("extracted_result.json")
if not os.path.isfile(result_path):
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
with open(result_path, "r", encoding="utf-8") as rf:
extracted_result = json.load(rf)
# 步骤 6: 计算本体覆盖率并合并到结果中
result_data = {
"config_id": cid,
@@ -460,15 +460,15 @@ class DataConfigService: # 数据配置服务类PostgreSQL
result_data["ontology_coverage"] = ontology_coverage
except Exception as cov_err:
logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True)
yield format_sse_message("result", result_data)
# 步骤 7: 发出完成事件
yield format_sse_message("done", {
"message": "试运行完成",
"time": int(time.time() * 1000)
})
except asyncio.CancelledError:
# 客户端断开连接
logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming")
@@ -483,11 +483,10 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"time": int(time.time() * 1000)
})
async def _compute_ontology_coverage(
self,
extracted_result: Dict[str, Any],
memory_config,
self,
extracted_result: Dict[str, Any],
memory_config,
) -> Optional[Dict[str, Any]]:
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
@@ -580,8 +579,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
# Ensure env for connector (e.g., NEO4J_PASSWORD)
load_dotenv()
_neo4j_connector = Neo4jConnector()
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
@@ -664,7 +661,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
# 检查结果是否为空或长度不足
if not result or len(result) < 4:
data = {
"total": 0,
"total": 0,
"distribution": [
{"type": "dialogue", "count": 0},
{"type": "chunk", "count": 0},
@@ -701,10 +698,11 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
)
return result
async def analytics_hot_memory_tags(
db: Session,
current_user: User,
limit: int = 10
db: Session,
current_user: User,
limit: int = 10
) -> List[Dict[str, Any]]:
"""
获取热门记忆标签按数量排序并返回前N个
@@ -721,27 +719,27 @@ async def analytics_hot_memory_tags(
from app.services.memory_dashboard_service import get_workspace_end_users
# 使用 asyncio.to_thread 避免阻塞事件循环
end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user)
if not end_users:
return []
# 步骤1: 收集所有用户的原始标签不调用LLM
connector = Neo4jConnector()
try:
all_raw_tags = []
for end_user in end_users:
raw_tags = await get_raw_tags_from_db(
connector,
str(end_user.id),
limit=raw_limit,
connector,
str(end_user.id),
limit=raw_limit,
by_user=False
)
if raw_tags:
all_raw_tags.extend(raw_tags)
if not all_raw_tags:
return []
# 步骤2: 聚合相同标签的频率
tag_frequency_map = {}
for tag_name, frequency in all_raw_tags:
@@ -749,36 +747,36 @@ async def analytics_hot_memory_tags(
tag_frequency_map[tag_name] += frequency
else:
tag_frequency_map[tag_name] = frequency
# 步骤3: 按频率降序排序取前raw_limit个
sorted_tags = sorted(
tag_frequency_map.items(),
key=lambda x: x[1],
tag_frequency_map.items(),
key=lambda x: x[1],
reverse=True
)[:raw_limit]
if not sorted_tags:
return []
# 步骤4: 只调用一次LLM进行筛选
tag_names = [tag for tag, _ in sorted_tags]
# 使用第一个用户的end_user_id来获取LLM配置
# 因为同一工作空间下的用户应该使用相同的配置
first_end_user_id = str(end_users[0].id)
filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id)
# 步骤5: 根据LLM筛选结果构建最终列表保留频率
final_tags = []
for tag, freq in sorted_tags:
if tag in filtered_tag_names:
final_tags.append((tag, freq))
# 步骤6: 只返回前limit个
top_tags = final_tags[:limit]
return [{"name": t, "frequency": f} for t, f in top_tags]
finally:
await connector.close()
@@ -815,11 +813,11 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
source = "log"
total = (
stats.get("chunk_count", 0)
+ stats.get("statements_count", 0)
+ stats.get("triplet_entities_count", 0)
+ stats.get("triplet_relations_count", 0)
+ stats.get("temporal_count", 0)
stats.get("chunk_count", 0)
+ stats.get("statements_count", 0)
+ stats.get("triplet_entities_count", 0)
+ stats.get("triplet_relations_count", 0)
+ stats.get("temporal_count", 0)
)
# 计算"最新一次活动多久前"(仅日志来源时有效)
@@ -845,5 +843,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
return data