Fix/memory bug fix (#171)

This commit is contained in:
lixinyue11
2026-01-26 11:53:34 +08:00
committed by GitHub
parent 714c624dc6
commit 3601737869
119 changed files with 1711 additions and 1695 deletions

View File

@@ -19,7 +19,7 @@ from app.core.memory.analytics.hot_memory_tags import (
)
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.models.user_model import User
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError
from app.schemas.memory_storage_schema import (
@@ -129,7 +129,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
if not params.rerank_id:
params.rerank_id = configs.get('rerank')
config = DataConfigRepository.create(self.db, params)
config = MemoryConfigRepository.create(self.db, params)
self.db.commit()
return {"affected": 1, "config_id": config.config_id}
@@ -146,20 +146,20 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# --- Delete ---
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数按配置ID
success = DataConfigRepository.delete(self.db, key.config_id)
success = MemoryConfigRepository.delete(self.db, key.config_id)
if not success:
raise ValueError("未找到配置")
return {"affected": 1}
# --- Update ---
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
config = DataConfigRepository.update(self.db, update)
config = MemoryConfigRepository.update(self.db, update)
if not config:
raise ValueError("未找到配置")
return {"affected": 1}
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
config = DataConfigRepository.update_extracted(self.db, update)
config = MemoryConfigRepository.update_extracted(self.db, update)
if not config:
raise ValueError("未找到配置")
return {"affected": 1}
@@ -170,14 +170,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# --- Read ---
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
result = DataConfigRepository.get_extracted_config(self.db, key.config_id)
result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id)
if not result:
raise ValueError("未找到配置")
return result
# --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
configs = DataConfigRepository.get_all(self.db, workspace_id)
configs = MemoryConfigRepository.get_all(self.db, workspace_id)
# 将 ORM 对象转换为字典列表
data_list = []
@@ -187,7 +187,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"config_name": config.config_name,
"config_desc": config.config_desc,
"workspace_id": str(config.workspace_id) if config.workspace_id else None,
"group_id": config.group_id,
"end_user_id": config.end_user_id,
"user_id": config.user_id,
"apply_id": config.apply_id,
"llm_id": config.llm_id,
@@ -395,8 +395,8 @@ _neo4j_connector = Neo4jConnector()
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_DIALOGUE,
group_id=end_user_id,
MemoryConfigRepository.SEARCH_FOR_DIALOGUE,
end_user_id=end_user_id,
)
data = {"search_for": "dialogue", "num": result[0]["num"]}
return data
@@ -404,8 +404,8 @@ async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_CHUNK,
group_id=end_user_id,
MemoryConfigRepository.SEARCH_FOR_CHUNK,
end_user_id=end_user_id,
)
data = {"search_for": "chunk", "num": result[0]["num"]}
return data
@@ -413,8 +413,8 @@ async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_STATEMENT,
group_id=end_user_id,
MemoryConfigRepository.SEARCH_FOR_STATEMENT,
end_user_id=end_user_id,
)
data = {"search_for": "statement", "num": result[0]["num"]}
return data
@@ -422,8 +422,8 @@ async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ENTITY,
group_id=end_user_id,
MemoryConfigRepository.SEARCH_FOR_ENTITY,
end_user_id=end_user_id,
)
data = {"search_for": "entity", "num": result[0]["num"]}
return data
@@ -431,8 +431,8 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ALL,
group_id=end_user_id,
MemoryConfigRepository.SEARCH_FOR_ALL,
end_user_id=end_user_id,
)
# 检查结果是否为空或长度不足
@@ -466,8 +466,8 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
聚合 dialogue/chunk/statement/entity 四类计数,返回统一的分布结构,便于前端一次性消费。
"""
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ALL,
group_id=end_user_id,
MemoryConfigRepository.SEARCH_FOR_ALL,
end_user_id=end_user_id,
)
# 检查结果是否为空或长度不足
@@ -497,21 +497,19 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_DETIALS,
group_id=end_user_id,
MemoryConfigRepository.SEARCH_FOR_DETIALS,
end_user_id=end_user_id,
)
return result
async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_EDGES,
group_id=end_user_id,
MemoryConfigRepository.SEARCH_FOR_EDGES,
end_user_id=end_user_id,
)
return result
async def analytics_hot_memory_tags(
db: Session,
current_user: User,
@@ -574,7 +572,7 @@ async def analytics_hot_memory_tags(
# 步骤4: 只调用一次LLM进行筛选
tag_names = [tag for tag, _ in sorted_tags]
# 使用第一个用户的group_id来获取LLM配置
# 使用第一个用户的end_user_id来获取LLM配置
# 因为同一工作空间下的用户应该使用相同的配置
first_end_user_id = str(end_users[0].id)
filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id)