Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import datetime
|
||||
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.app_model import App
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
@@ -92,6 +95,157 @@ class EndUserRepository:
|
||||
db_logger.error(f"获取或创建终端用户时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||
"""根据ID获取终端用户(用于缓存操作)
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
Optional[EndUser]: 终端用户对象,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
end_user = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.first()
|
||||
)
|
||||
if end_user:
|
||||
db_logger.debug(f"成功查询到终端用户 {end_user_id}")
|
||||
else:
|
||||
db_logger.debug(f"未找到终端用户 {end_user_id}")
|
||||
return end_user
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"查询终端用户 {end_user_id} 时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_memory_insight(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
insight: str
|
||||
) -> bool:
|
||||
"""更新记忆洞察缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
insight: 记忆洞察内容
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
updated_count = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.update(
|
||||
{
|
||||
EndUser.memory_insight: insight,
|
||||
EndUser.memory_insight_updated_at: datetime.datetime.now()
|
||||
},
|
||||
synchronize_session=False
|
||||
)
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
if updated_count > 0:
|
||||
db_logger.info(f"成功更新终端用户 {end_user_id} 的记忆洞察缓存")
|
||||
return True
|
||||
else:
|
||||
db_logger.warning(f"未找到终端用户 {end_user_id},无法更新记忆洞察缓存")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的记忆洞察缓存时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_user_summary(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
summary: str
|
||||
) -> bool:
|
||||
"""更新用户摘要缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
summary: 用户摘要内容
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
updated_count = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.update(
|
||||
{
|
||||
EndUser.user_summary: summary,
|
||||
EndUser.user_summary_updated_at: datetime.datetime.now()
|
||||
},
|
||||
synchronize_session=False
|
||||
)
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
if updated_count > 0:
|
||||
db_logger.info(f"成功更新终端用户 {end_user_id} 的用户摘要缓存")
|
||||
return True
|
||||
else:
|
||||
db_logger.warning(f"未找到终端用户 {end_user_id},无法更新用户摘要缓存")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的用户摘要缓存时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]:
|
||||
"""获取工作空间的所有终端用户
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
List[EndUser]: 终端用户列表
|
||||
"""
|
||||
try:
|
||||
end_users = (
|
||||
self.db.query(EndUser)
|
||||
.join(App, EndUser.app_id == App.id)
|
||||
.filter(App.workspace_id == workspace_id)
|
||||
.all()
|
||||
)
|
||||
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户")
|
||||
return end_users
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"查询工作空间 {workspace_id} 下的终端用户时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all_active_workspaces(self) -> List[uuid.UUID]:
|
||||
"""获取所有活动工作空间的ID
|
||||
|
||||
Returns:
|
||||
List[uuid.UUID]: 活动工作空间ID列表
|
||||
"""
|
||||
try:
|
||||
workspace_ids = (
|
||||
self.db.query(Workspace.id)
|
||||
.filter(Workspace.is_active)
|
||||
.all()
|
||||
)
|
||||
# 提取ID(查询返回的是元组列表)
|
||||
workspace_id_list = [workspace_id[0] for workspace_id in workspace_ids]
|
||||
db_logger.info(f"成功查询到 {len(workspace_id_list)} 个活动工作空间")
|
||||
return workspace_id_list
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"查询活动工作空间时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
|
||||
"""根据应用ID查询宿主(返回 EndUser ORM 列表)"""
|
||||
repo = EndUserRepository(db)
|
||||
@@ -138,4 +292,30 @@ def update_end_user_other_name(
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}")
|
||||
raise
|
||||
raise
|
||||
|
||||
# 新增的缓存操作函数(保持与类方法一致的接口)
|
||||
def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||
"""根据ID获取终端用户(用于缓存操作)"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.get_by_id(end_user_id)
|
||||
|
||||
def update_memory_insight(db: Session, end_user_id: uuid.UUID, insight: str) -> bool:
|
||||
"""更新记忆洞察缓存"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.update_memory_insight(end_user_id, insight)
|
||||
|
||||
def update_user_summary(db: Session, end_user_id: uuid.UUID, summary: str) -> bool:
|
||||
"""更新用户摘要缓存"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.update_user_summary(end_user_id, summary)
|
||||
|
||||
def get_all_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[EndUser]:
|
||||
"""获取工作空间的所有终端用户"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.get_all_by_workspace(workspace_id)
|
||||
|
||||
def get_all_active_workspaces(db: Session) -> List[uuid.UUID]:
|
||||
"""获取所有活动工作空间的ID"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.get_all_active_workspaces()
|
||||
|
||||
@@ -783,7 +783,9 @@ neo4j_query_part = """
|
||||
m.created_at as created_at,
|
||||
m.expired_at as expired_at,
|
||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||
rel as relationship,
|
||||
rel.predicate as predicate,
|
||||
rel.statement as relationship,
|
||||
rel.statement_id as relationship_statement_id,
|
||||
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
||||
other as entity2
|
||||
"""
|
||||
@@ -799,7 +801,9 @@ neo4j_query_all = """
|
||||
m.created_at as created_at,
|
||||
m.expired_at as expired_at,
|
||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||
rel as relationship,
|
||||
rel.predicate as predicate,
|
||||
rel.statement as relationship,
|
||||
rel.statement_id as relationship_statement_id,
|
||||
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
||||
other as entity2
|
||||
"""
|
||||
|
||||
@@ -67,11 +67,81 @@ async def update_neo4j_data(neo4j_dict_data, update_databases):
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def update_neo4j_data_edge(neo4j_dict_data, update_databases):
|
||||
"""
|
||||
Update Neo4j data based on query criteria and update parameters
|
||||
|
||||
Args:
|
||||
neo4j_dict_data: find
|
||||
update_databases: update
|
||||
"""
|
||||
try:
|
||||
# 构建WHERE条件
|
||||
where_conditions = []
|
||||
params = {}
|
||||
|
||||
for key, value in neo4j_dict_data.items():
|
||||
if value is not None:
|
||||
param_name = f"param_{key}"
|
||||
where_conditions.append(f"r.{key} = ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||||
|
||||
# 构建SET条件
|
||||
set_conditions = []
|
||||
for key, value in update_databases.items():
|
||||
if value is not None:
|
||||
param_name = f"update_{key}"
|
||||
set_conditions.append(f"r.{key} = ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
set_clause = ", ".join(set_conditions)
|
||||
|
||||
if not set_clause:
|
||||
print("警告: 没有需要更新的字段")
|
||||
return False
|
||||
|
||||
# 构建Cypher查询
|
||||
cypher_query = f"""
|
||||
MATCH (n)-[r]->(m)
|
||||
WHERE {where_clause}
|
||||
SET {set_clause}
|
||||
RETURN count(r) as updated_count, collect(type(r)) as relation_types
|
||||
"""
|
||||
|
||||
print(f"\n执行Cypher查询: {cypher_query}")
|
||||
print(f"参数: {params}")
|
||||
|
||||
# 执行更新
|
||||
result = await neo4j_connector.execute_query(cypher_query, **params)
|
||||
|
||||
if result:
|
||||
updated_count = result[0].get('updated_count', 0)
|
||||
updated_names = result[0].get('updated_names', [])
|
||||
print(f"成功更新 {updated_count} 个节点")
|
||||
if updated_names:
|
||||
print(f"更新的实体名称: {updated_names}")
|
||||
return updated_count > 0
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"更新过程中出现错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
def map_field_names(data_dict):
|
||||
mapped_dict = {}
|
||||
has_name_field = False
|
||||
|
||||
# 辅助函数:提取值(如果是数组则取最后一个值,否则直接返回)
|
||||
def extract_value(value):
|
||||
if isinstance(value, list) and len(value) > 0:
|
||||
# 如果是数组 [old_value, new_value],取新值(最后一个)
|
||||
return value[-1]
|
||||
return value
|
||||
|
||||
# 第一遍:检查是否有name相关字段
|
||||
for key, value in data_dict.items():
|
||||
if key in ['name', 'entity2.name', 'entity1.name']:
|
||||
@@ -82,22 +152,25 @@ def map_field_names(data_dict):
|
||||
|
||||
# 第二遍:根据规则映射和过滤字段
|
||||
for key, value in data_dict.items():
|
||||
# 提取实际值(处理数组格式)
|
||||
actual_value = extract_value(value)
|
||||
|
||||
if key == 'entity2.name' or key == 'entity2_name':
|
||||
# 将 entity2.name 映射为 name
|
||||
mapped_dict['name'] = value
|
||||
print(f"字段名映射: {key} -> name")
|
||||
mapped_dict['name'] = actual_value
|
||||
print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})")
|
||||
elif key == 'entity1.name' or key == 'entity1_name':
|
||||
# 将 entity1.name 映射为 name
|
||||
mapped_dict['name'] = value
|
||||
print(f"字段名映射: {key} -> name")
|
||||
mapped_dict['name'] = actual_value
|
||||
print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})")
|
||||
elif key == 'entity1.description':
|
||||
# 将 entity1.description 映射为 description
|
||||
mapped_dict['description'] = value
|
||||
print(f"字段名映射: {key} -> description")
|
||||
mapped_dict['description'] = actual_value
|
||||
print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})")
|
||||
elif key == 'entity2.description':
|
||||
# 将 entity2.description 映射为 description
|
||||
mapped_dict['description'] = value
|
||||
print(f"字段名映射: {key} -> description")
|
||||
mapped_dict['description'] = actual_value
|
||||
print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})")
|
||||
elif key == 'relationship_type':
|
||||
# 跳过relationship_type字段
|
||||
print(f"字段过滤: 跳过不需要的字段 '{key}'")
|
||||
@@ -109,8 +182,8 @@ def map_field_names(data_dict):
|
||||
continue
|
||||
else:
|
||||
# 如果没有name字段,保留entity1_name
|
||||
mapped_dict[key] = value
|
||||
print(f"字段保留: {key}")
|
||||
mapped_dict[key] = actual_value
|
||||
print(f"字段保留: {key} (值: {value} -> {actual_value})")
|
||||
elif key == 'entity2_name':
|
||||
if has_name_field:
|
||||
# 如果有name字段,跳过entity2_name
|
||||
@@ -122,7 +195,11 @@ def map_field_names(data_dict):
|
||||
continue
|
||||
elif '.' not in key:
|
||||
# 不包含点号的其他字段直接保留
|
||||
mapped_dict[key] = value
|
||||
mapped_dict[key] = actual_value
|
||||
if isinstance(value, list):
|
||||
print(f"字段保留: {key} (数组值: {value} -> {actual_value})")
|
||||
else:
|
||||
print(f"字段保留: {key}")
|
||||
else:
|
||||
# 其他包含点号的字段跳过并警告
|
||||
print(f"警告: 跳过不支持的嵌套字段 '{key}'")
|
||||
@@ -139,89 +216,57 @@ async def neo4j_data(solved_data):
|
||||
"""
|
||||
success_count = 0
|
||||
|
||||
ori_entity = {}
|
||||
updata_entity = {}
|
||||
ori_edge = {}
|
||||
updata_edge = {}
|
||||
ori_expired_at={}
|
||||
updat_expired_at={}
|
||||
for i in solved_data:
|
||||
neo4j_dict_data = {}
|
||||
update_databases = {}
|
||||
results = i['results']
|
||||
for data in results:
|
||||
resolved = data.get('resolved')
|
||||
if not resolved:
|
||||
print("跳过:resolved为None")
|
||||
databasets = i['data']
|
||||
for key, values in databasets.items():
|
||||
if str(values)=='NONE':
|
||||
continue
|
||||
if isinstance(values, list):
|
||||
if key == 'description':
|
||||
ori_entity[key] = values[0]
|
||||
updata_entity[key] = values[1]
|
||||
if key == 'entity2_name' or key == 'entity1_name':
|
||||
key = 'name'
|
||||
ori_entity[key] = values[0]
|
||||
updata_entity[key] = values[1]
|
||||
ori_expired_at[key] = values[0]
|
||||
if key == 'statement':
|
||||
ori_edge[key] = values[0]
|
||||
updata_edge[key] = values[1]
|
||||
if key=='expired_at':
|
||||
updat_expired_at[key] = values[1]
|
||||
|
||||
try:
|
||||
change_list = resolved.get('change', [])
|
||||
except (AttributeError, TypeError):
|
||||
change_list = []
|
||||
elif key == 'statement_id':
|
||||
ori_edge[key] = values
|
||||
updata_edge[key] = values
|
||||
|
||||
if change_list == []:
|
||||
print("跳过:change_list为空")
|
||||
continue
|
||||
ori_entity[key] = values
|
||||
updata_entity[key] = values
|
||||
|
||||
if change_list and len(change_list) > 0:
|
||||
change = change_list[0]
|
||||
print(f"change: {change}")
|
||||
field_data = change.get('field', [])
|
||||
print(f"field_data: {field_data}")
|
||||
print(f"field_data type: {type(field_data)}")
|
||||
|
||||
# 字段名映射和过滤函数
|
||||
ori_expired_at[key] = values
|
||||
|
||||
|
||||
# 处理field数据,可能是字典或列表
|
||||
if isinstance(field_data, dict):
|
||||
# 如果是字典,映射字段名后更新
|
||||
mapped_data = map_field_names(field_data)
|
||||
update_databases.update(mapped_data)
|
||||
elif isinstance(field_data, list):
|
||||
# 如果是列表,遍历每个字典并更新
|
||||
for field_item in field_data:
|
||||
if isinstance(field_item, dict):
|
||||
mapped_item = map_field_names(field_item)
|
||||
update_databases.update(mapped_item)
|
||||
else:
|
||||
print(f"警告: field_item不是字典: {field_item}")
|
||||
else:
|
||||
print(f"警告: field_data类型不支持: {type(field_data)}")
|
||||
|
||||
if 'entity1_name' in data:
|
||||
data['name'] = data.pop('entity1_name')
|
||||
if 'entity2_name' in data:
|
||||
data.pop('entity2_name', None)
|
||||
|
||||
resolved_memory = resolved.get('resolved_memory', {})
|
||||
|
||||
entity2 = None
|
||||
if isinstance(resolved_memory, dict):
|
||||
entity2 = resolved_memory.get('entity2')
|
||||
|
||||
if entity2 and isinstance(entity2, dict) and len(entity2) >= 5:
|
||||
stat_id = resolved.get('original_memory_id')
|
||||
# 安全地获取description
|
||||
statement_id = None
|
||||
if isinstance(resolved_memory, dict):
|
||||
statement_id = resolved_memory.get('statement_id')
|
||||
|
||||
# 只有当neo4j_dict_data中还没有statement_id时才使用original_memory_id
|
||||
if statement_id and 'id' not in neo4j_dict_data:
|
||||
neo4j_dict_data['id'] = stat_id
|
||||
neo4j_dict_data['statement_id'] = statement_id
|
||||
else:
|
||||
# 处理original_memory_id,它可能是字符串或字典
|
||||
try:
|
||||
for key, value in resolved_memory.items():
|
||||
if key == 'statement_id':
|
||||
neo4j_dict_data['statement_id'] = value
|
||||
if key == 'description':
|
||||
neo4j_dict_data['description'] = value
|
||||
except AttributeError:
|
||||
neo4j_dict_data=[]
|
||||
|
||||
print(neo4j_dict_data)
|
||||
print(update_databases)
|
||||
if neo4j_dict_data!=[]:
|
||||
await update_neo4j_data(neo4j_dict_data, update_databases)
|
||||
success_count += 1
|
||||
print(ori_entity)
|
||||
print(updata_entity)
|
||||
print(100*'-')
|
||||
print(ori_edge)
|
||||
print(updata_edge)
|
||||
expired_at_ = updat_expired_at.get('expired_at', None)
|
||||
if expired_at_ is not None:
|
||||
await update_neo4j_data(ori_expired_at, updat_expired_at)
|
||||
success_count += 1
|
||||
if ori_entity != updata_entity:
|
||||
await update_neo4j_data(ori_entity, updata_entity)
|
||||
success_count += 1
|
||||
if ori_edge != updata_edge:
|
||||
await update_neo4j_data_edge(ori_edge, updata_edge)
|
||||
success_count += 1
|
||||
|
||||
return success_count
|
||||
|
||||
|
||||
Reference in New Issue
Block a user