Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
This commit is contained in:
@@ -16,7 +16,6 @@ from app.models.data_config_model import DataConfig
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
@@ -29,37 +28,37 @@ db_logger = get_db_logger()
|
||||
# 获取配置专用日志器
|
||||
config_logger = get_config_logger()
|
||||
|
||||
|
||||
TABLE_NAME = "data_config"
|
||||
class DataConfigRepository:
|
||||
"""数据配置Repository
|
||||
|
||||
|
||||
提供data_config表的数据访问方法,包括:
|
||||
- SQLAlchemy ORM 数据库操作
|
||||
- Neo4j Cypher查询常量
|
||||
"""
|
||||
|
||||
|
||||
# ==================== Neo4j Cypher 查询常量 ====================
|
||||
|
||||
|
||||
# Dialogue count by group
|
||||
SEARCH_FOR_DIALOGUE = """
|
||||
MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# Chunk count by group
|
||||
SEARCH_FOR_CHUNK = """
|
||||
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# Statement count by group
|
||||
SEARCH_FOR_STATEMENT = """
|
||||
MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# ExtractedEntity count by group
|
||||
SEARCH_FOR_ENTITY = """
|
||||
MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# All counts by label and total
|
||||
SEARCH_FOR_ALL = """
|
||||
OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
|
||||
@@ -72,7 +71,7 @@ class DataConfigRepository:
|
||||
UNION ALL
|
||||
OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count
|
||||
"""
|
||||
|
||||
|
||||
# Extracted entity details within group/app/user
|
||||
SEARCH_FOR_DETIALS = """
|
||||
MATCH (n:ExtractedEntity)
|
||||
@@ -88,7 +87,7 @@ class DataConfigRepository:
|
||||
n.user_id AS user_id,
|
||||
n.id AS id
|
||||
"""
|
||||
|
||||
|
||||
# Edges between extracted entities within group/app/user
|
||||
SEARCH_FOR_EDGES = """
|
||||
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
|
||||
@@ -104,7 +103,7 @@ class DataConfigRepository:
|
||||
r.statement_id AS statement_id,
|
||||
r.statement AS statement
|
||||
"""
|
||||
|
||||
|
||||
# Entity graph within group (source node, edge, target node)
|
||||
SEARCH_FOR_ENTITY_GRAPH = """
|
||||
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
|
||||
@@ -137,22 +136,106 @@ class DataConfigRepository:
|
||||
id: m.id
|
||||
} AS targetNode
|
||||
"""
|
||||
|
||||
|
||||
# ==================== SQLAlchemy ORM 数据库操作方法 ====================
|
||||
|
||||
@staticmethod
|
||||
def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]:
|
||||
"""构建反思配置更新语句(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
**kwargs: 反思配置参数
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")
|
||||
|
||||
key_where = "config_id = :config_id"
|
||||
set_fields: List[str] = []
|
||||
params: Dict = {
|
||||
"config_id": config_id,
|
||||
}
|
||||
|
||||
# 反思配置字段映射
|
||||
mapping = {
|
||||
"enable_self_reflexion": "enable_self_reflexion",
|
||||
"iteration_period": "iteration_period",
|
||||
"reflexion_range": "reflexion_range",
|
||||
"baseline": "baseline",
|
||||
"reflection_model_id": "reflection_model_id",
|
||||
"memory_verify": "memory_verify",
|
||||
"quality_assessment": "quality_assessment",
|
||||
}
|
||||
|
||||
for api_field, db_col in mapping.items():
|
||||
if api_field in kwargs and kwargs[api_field] is not None:
|
||||
set_fields.append(f"{db_col} = :{api_field}")
|
||||
params[api_field] = kwargs[api_field]
|
||||
|
||||
if not set_fields:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
set_fields.append("updated_at = timezone('Asia/Shanghai', now())")
|
||||
query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}"
|
||||
return query, params
|
||||
|
||||
@staticmethod
|
||||
def build_select_reflection(config_id: int) -> Tuple[str, Dict]:
|
||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
||||
"""
|
||||
db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")
|
||||
|
||||
query = (
|
||||
f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, "
|
||||
f"reflection_model_id, memory_verify, quality_assessment, user_id "
|
||||
f"FROM {TABLE_NAME} WHERE config_id = :config_id"
|
||||
)
|
||||
params = {"config_id": config_id}
|
||||
return query, params
|
||||
|
||||
@staticmethod
|
||||
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
|
||||
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
||||
"""
|
||||
db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")
|
||||
|
||||
query = (
|
||||
f"SELECT config_id, config_name, enable_self_reflexion, iteration_period, reflexion_range, baseline, "
|
||||
f"reflection_model_id, memory_verify, quality_assessment, user_id, created_at, updated_at "
|
||||
f"FROM {TABLE_NAME} WHERE workspace_id = :workspace_id ORDER BY updated_at DESC"
|
||||
)
|
||||
params = {"workspace_id": workspace_id}
|
||||
return query, params
|
||||
|
||||
@staticmethod
|
||||
def create(db: Session, params: ConfigParamsCreate) -> DataConfig:
|
||||
"""创建数据配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
params: 配置参数创建模型
|
||||
|
||||
|
||||
Returns:
|
||||
DataConfig: 创建的配置对象
|
||||
"""
|
||||
db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = DataConfig(
|
||||
config_name=params.config_name,
|
||||
@@ -164,37 +247,37 @@ class DataConfigRepository:
|
||||
)
|
||||
db.add(db_config)
|
||||
db.flush() # 获取自增ID但不提交事务
|
||||
|
||||
|
||||
db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]:
|
||||
"""更新基础配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
update: 配置更新模型
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新数据配置: config_id={update.config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新字段
|
||||
has_update = False
|
||||
if update.config_name is not None:
|
||||
@@ -203,44 +286,44 @@ class DataConfigRepository:
|
||||
if update.config_desc is not None:
|
||||
db_config.config_desc = update.config_desc
|
||||
has_update = True
|
||||
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
|
||||
db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]:
|
||||
"""更新记忆萃取引擎配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
update: 萃取配置更新模型
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新字段映射
|
||||
field_mapping = {
|
||||
# 模型选择
|
||||
@@ -270,50 +353,50 @@ class DataConfigRepository:
|
||||
"reflexion_range": "reflexion_range",
|
||||
"baseline": "baseline",
|
||||
}
|
||||
|
||||
|
||||
has_update = False
|
||||
for api_field, db_field in field_mapping.items():
|
||||
value = getattr(update, api_field, None)
|
||||
if value is not None:
|
||||
setattr(db_config, db_field, value)
|
||||
has_update = True
|
||||
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
|
||||
db_logger.info(f"萃取配置更新成功: config_id={update.config_id}")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新萃取配置失败: config_id={update.config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]:
|
||||
"""更新遗忘引擎配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
update: 遗忘配置更新模型
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新遗忘配置: config_id={update.config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新字段
|
||||
has_update = False
|
||||
if update.lambda_time is not None:
|
||||
@@ -325,40 +408,40 @@ class DataConfigRepository:
|
||||
if update.offset is not None:
|
||||
db_config.offset = update.offset
|
||||
has_update = True
|
||||
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
|
||||
db_logger.info(f"遗忘配置更新成功: config_id={update.config_id}")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新遗忘配置失败: config_id={update.config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]:
|
||||
"""获取萃取配置,通过主键查询某条配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 萃取配置字典,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询萃取配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.debug(f"萃取配置不存在: config_id={config_id}")
|
||||
return None
|
||||
|
||||
|
||||
result = {
|
||||
"llm_id": db_config.llm_id,
|
||||
"embedding_id": db_config.embedding_id,
|
||||
@@ -381,62 +464,62 @@ class DataConfigRepository:
|
||||
"reflexion_range": db_config.reflexion_range,
|
||||
"baseline": db_config.baseline,
|
||||
}
|
||||
|
||||
|
||||
db_logger.debug(f"萃取配置查询成功: config_id={config_id}")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询萃取配置失败: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]:
|
||||
"""获取遗忘配置,通过主键查询某条配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 遗忘配置字典,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询遗忘配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.debug(f"遗忘配置不存在: config_id={config_id}")
|
||||
return None
|
||||
|
||||
|
||||
result = {
|
||||
"lambda_time": db_config.lambda_time,
|
||||
"lambda_mem": db_config.lambda_mem,
|
||||
"offset": db_config.offset,
|
||||
}
|
||||
|
||||
|
||||
db_logger.debug(f"遗忘配置查询成功: config_id={config_id}")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询遗忘配置失败: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]:
|
||||
"""根据ID获取数据配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 配置对象,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"根据ID查询数据配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
|
||||
|
||||
if config:
|
||||
db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})")
|
||||
else:
|
||||
@@ -571,56 +654,56 @@ class DataConfigRepository:
|
||||
@staticmethod
|
||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
|
||||
"""获取所有配置参数
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID,用于过滤查询结果
|
||||
|
||||
|
||||
Returns:
|
||||
List[DataConfig]: 配置列表
|
||||
"""
|
||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
query = db.query(DataConfig)
|
||||
|
||||
|
||||
if workspace_id:
|
||||
query = query.filter(DataConfig.workspace_id == workspace_id)
|
||||
|
||||
|
||||
configs = query.order_by(desc(DataConfig.updated_at)).all()
|
||||
|
||||
|
||||
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
|
||||
return configs
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, config_id: int) -> bool:
|
||||
"""删除数据配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 删除成功返回True,配置不存在返回False
|
||||
"""
|
||||
db_logger.debug(f"删除数据配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={config_id}")
|
||||
return False
|
||||
|
||||
|
||||
db.delete(db_config)
|
||||
db.commit()
|
||||
|
||||
|
||||
db_logger.info(f"数据配置删除成功: config_id={config_id}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}")
|
||||
|
||||
@@ -115,7 +115,9 @@ def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Kn
|
||||
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
knowledge = db.query(Knowledge).filter(Knowledge.name == name).filter(Knowledge.workspace_id == workspace_id).first()
|
||||
knowledge = db.query(Knowledge).filter(Knowledge.name == name,
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1).first()
|
||||
if knowledge:
|
||||
db_logger.debug(f"knowledge base query successful: {name} (ID: {knowledge.id})")
|
||||
else:
|
||||
|
||||
@@ -3,9 +3,9 @@ from sqlalchemy import and_, or_, func, desc
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import uuid
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
@@ -32,7 +32,7 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
@@ -92,7 +92,7 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
@@ -117,13 +117,21 @@ class ModelConfigRepository:
|
||||
filters.append(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
# 支持多个 type 值(使用 IN 查询)
|
||||
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
|
||||
if query.type:
|
||||
filters.append(ModelConfig.type.in_(query.type))
|
||||
type_values = list(query.type)
|
||||
# 如果包含 chat 或 llm,则同时包含两者
|
||||
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
|
||||
if ModelType.CHAT not in type_values:
|
||||
type_values.append(ModelType.CHAT)
|
||||
if ModelType.LLM not in type_values:
|
||||
type_values.append(ModelType.LLM)
|
||||
filters.append(ModelConfig.type.in_(type_values))
|
||||
|
||||
if query.is_active is not None:
|
||||
filters.append(ModelConfig.is_active == query.is_active)
|
||||
@@ -183,12 +191,12 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
if is_active:
|
||||
query = query.filter(ModelConfig.is_active == True)
|
||||
query = query.filter(ModelConfig.is_active)
|
||||
|
||||
models = query.order_by(ModelConfig.name).all()
|
||||
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
|
||||
@@ -285,7 +293,7 @@ class ModelConfigRepository:
|
||||
try:
|
||||
# 总数统计
|
||||
total_models = db.query(ModelConfig).count()
|
||||
active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count()
|
||||
active_models = db.query(ModelConfig).filter(ModelConfig.is_active).count()
|
||||
|
||||
# 按类型统计
|
||||
llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count()
|
||||
@@ -344,7 +352,7 @@ class ModelApiKeyRepository:
|
||||
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
|
||||
|
||||
if is_active:
|
||||
query = query.filter(ModelApiKey.is_active == True)
|
||||
query = query.filter(ModelApiKey.is_active)
|
||||
|
||||
api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all()
|
||||
db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}")
|
||||
|
||||
@@ -100,7 +100,13 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
# "triplets": [triplet.model_dump() for triplet in statement.triplet_extraction_info.triplets] if statement.triplet_extraction_info else [],
|
||||
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
|
||||
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
|
||||
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None
|
||||
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
|
||||
# 添加情绪字段处理
|
||||
"emotion_type": statement.emotion_type,
|
||||
"emotion_intensity": statement.emotion_intensity,
|
||||
"emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [],
|
||||
"emotion_subject": statement.emotion_subject,
|
||||
"emotion_target": statement.emotion_target
|
||||
}
|
||||
flattened_statements.append(flattened_statement)
|
||||
|
||||
|
||||
@@ -20,20 +20,25 @@ UNWIND $statements AS statement
|
||||
MERGE (s:Statement {id: statement.id})
|
||||
SET s += {
|
||||
id: statement.id,
|
||||
run_id: statement.run_id,
|
||||
chunk_id: statement.chunk_id,
|
||||
group_id: statement.group_id,
|
||||
user_id: statement.user_id,
|
||||
apply_id: statement.apply_id,
|
||||
chunk_id: statement.chunk_id,
|
||||
run_id: statement.run_id,
|
||||
stmt_type: statement.stmt_type,
|
||||
statement: statement.statement,
|
||||
emotion_intensity: statement.emotion_intensity,
|
||||
emotion_target: statement.emotion_target,
|
||||
emotion_subject: statement.emotion_subject,
|
||||
emotion_type: statement.emotion_type,
|
||||
emotion_keywords: statement.emotion_keywords,
|
||||
temporal_info: statement.temporal_info,
|
||||
created_at: statement.created_at,
|
||||
expired_at: statement.expired_at,
|
||||
stmt_type: statement.stmt_type,
|
||||
temporal_info: statement.temporal_info,
|
||||
relevence_info: statement.relevence_info,
|
||||
statement: statement.statement,
|
||||
valid_at: statement.valid_at,
|
||||
invalid_at: statement.invalid_at,
|
||||
statement_embedding: statement.statement_embedding
|
||||
statement_embedding: statement.statement_embedding,
|
||||
relevence_info: statement.relevence_info
|
||||
}
|
||||
RETURN s.id AS uuid
|
||||
"""
|
||||
@@ -746,3 +751,57 @@ DETACH DELETE losing
|
||||
|
||||
RETURN count(losing) as deleted
|
||||
"""
|
||||
|
||||
neo4j_statement_part = '''
|
||||
MATCH (n:Statement)
|
||||
WHERE n.group_id = "{}"
|
||||
AND datetime(n.created_at) >= datetime() - duration('P3D')
|
||||
RETURN
|
||||
n.statement as statement_name,
|
||||
n.id as statement_id,
|
||||
n.created_at as statement_created_at
|
||||
|
||||
'''
|
||||
neo4j_statement_all = '''
|
||||
MATCH (n:Statement)
|
||||
WHERE n.group_id = "{}"
|
||||
RETURN
|
||||
n.statement as statement_name,
|
||||
n.id as statement_id
|
||||
|
||||
'''
|
||||
neo4j_query_part = """
|
||||
MATCH (n)-[r]-(m:ExtractedEntity)
|
||||
WHERE n.group_id = "{}"
|
||||
AND datetime(n.created_at) >= datetime() - duration('P3D')
|
||||
WITH DISTINCT m
|
||||
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
||||
RETURN
|
||||
m.name as entity1_name,
|
||||
m.description as description,
|
||||
m.statement_id as statement_id,
|
||||
m.created_at as created_at,
|
||||
m.expired_at as expired_at,
|
||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||
rel as relationship,
|
||||
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
||||
other as entity2
|
||||
"""
|
||||
neo4j_query_all = """
|
||||
MATCH (n)-[r]-(m:ExtractedEntity)
|
||||
WHERE n.group_id = "{}"
|
||||
WITH DISTINCT m
|
||||
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
||||
RETURN
|
||||
m.name as entity1_name,
|
||||
m.description as description,
|
||||
m.statement_id as statement_id,
|
||||
m.created_at as created_at,
|
||||
m.expired_at as expired_at,
|
||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||
rel as relationship,
|
||||
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
||||
other as entity2
|
||||
"""
|
||||
|
||||
|
||||
|
||||
246
api/app/repositories/neo4j/emotion_repository.py
Normal file
246
api/app/repositories/neo4j/emotion_repository.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""情绪数据仓储模块
|
||||
|
||||
本模块提供情绪数据的查询功能,用于情绪分析和统计。
|
||||
|
||||
Classes:
|
||||
EmotionRepository: 情绪数据仓储,提供情绪标签、词云、健康指数等查询方法
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class EmotionRepository:
|
||||
"""情绪数据仓储
|
||||
|
||||
提供情绪数据的查询和统计功能,包括:
|
||||
- 情绪标签统计
|
||||
- 情绪词云数据
|
||||
- 时间范围内的情绪数据查询
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j连接器实例
|
||||
"""
|
||||
|
||||
def __init__(self, connector: Neo4jConnector):
|
||||
"""初始化情绪数据仓储
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
"""
|
||||
self.connector = connector
|
||||
logger.info("情绪数据仓储初始化完成")
|
||||
|
||||
async def get_emotion_tags(
|
||||
self,
|
||||
group_id: str,
|
||||
emotion_type: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取情绪标签统计
|
||||
|
||||
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
|
||||
|
||||
Args:
|
||||
group_id: 用户组ID(宿主ID)
|
||||
emotion_type: 可选的情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)
|
||||
start_date: 可选的开始日期(ISO格式字符串)
|
||||
end_date: 可选的结束日期(ISO格式字符串)
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 情绪标签列表,每个包含:
|
||||
- emotion_type: 情绪类型
|
||||
- count: 该类型的数量
|
||||
- percentage: 占比百分比
|
||||
- avg_intensity: 平均强度
|
||||
"""
|
||||
# 构建查询条件
|
||||
where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"]
|
||||
params = {"group_id": group_id, "limit": limit}
|
||||
|
||||
if emotion_type:
|
||||
where_clauses.append("s.emotion_type = $emotion_type")
|
||||
params["emotion_type"] = emotion_type
|
||||
|
||||
if start_date:
|
||||
where_clauses.append("s.created_at >= $start_date")
|
||||
params["start_date"] = start_date
|
||||
|
||||
if end_date:
|
||||
where_clauses.append("s.created_at <= $end_date")
|
||||
params["end_date"] = end_date
|
||||
|
||||
where_str = " AND ".join(where_clauses)
|
||||
|
||||
# 优化的 Cypher 查询:使用索引,减少中间结果
|
||||
query = f"""
|
||||
MATCH (s:Statement)
|
||||
WHERE {where_str}
|
||||
WITH s.emotion_type as emotion_type,
|
||||
count(*) as count,
|
||||
avg(s.emotion_intensity) as avg_intensity
|
||||
WITH collect({{emotion_type: emotion_type, count: count, avg_intensity: avg_intensity}}) as results,
|
||||
sum(count) as total_count
|
||||
UNWIND results as result
|
||||
RETURN result.emotion_type as emotion_type,
|
||||
result.count as count,
|
||||
toFloat(result.count) / total_count * 100 as percentage,
|
||||
result.avg_intensity as avg_intensity
|
||||
ORDER BY count DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
try:
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
formatted_results = [
|
||||
{
|
||||
"emotion_type": record["emotion_type"],
|
||||
"count": record["count"],
|
||||
"percentage": round(record["percentage"], 2),
|
||||
"avg_intensity": round(record["avg_intensity"], 3) if record["avg_intensity"] else 0.0
|
||||
}
|
||||
for record in results
|
||||
]
|
||||
|
||||
return formatted_results
|
||||
except Exception as e:
|
||||
logger.error(f"查询情绪标签失败: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_emotion_wordcloud(
|
||||
self,
|
||||
group_id: str,
|
||||
emotion_type: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取情绪词云数据
|
||||
|
||||
查询情绪关键词及其频率,用于生成词云可视化。
|
||||
|
||||
Args:
|
||||
group_id: 用户组ID(宿主ID)
|
||||
emotion_type: 可选的情绪类型过滤
|
||||
limit: 返回关键词的最大数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 关键词列表,每个包含:
|
||||
- keyword: 关键词
|
||||
- frequency: 出现频率
|
||||
- emotion_type: 关联的情绪类型
|
||||
- avg_intensity: 平均强度
|
||||
"""
|
||||
# 构建查询条件
|
||||
where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"]
|
||||
params = {"group_id": group_id, "limit": limit}
|
||||
|
||||
if emotion_type:
|
||||
where_clauses.append("s.emotion_type = $emotion_type")
|
||||
params["emotion_type"] = emotion_type
|
||||
|
||||
where_str = " AND ".join(where_clauses)
|
||||
|
||||
# 优化的 Cypher 查询:使用索引,减少不必要的计算
|
||||
query = f"""
|
||||
MATCH (s:Statement)
|
||||
WHERE {where_str}
|
||||
UNWIND s.emotion_keywords as keyword
|
||||
WITH keyword,
|
||||
s.emotion_type as emotion_type,
|
||||
count(*) as frequency,
|
||||
avg(s.emotion_intensity) as avg_intensity
|
||||
WHERE keyword IS NOT NULL AND keyword <> ''
|
||||
RETURN keyword,
|
||||
frequency,
|
||||
emotion_type,
|
||||
avg_intensity
|
||||
ORDER BY frequency DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
try:
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
formatted_results = [
|
||||
{
|
||||
"keyword": record["keyword"],
|
||||
"frequency": record["frequency"],
|
||||
"emotion_type": record["emotion_type"],
|
||||
"avg_intensity": round(record["avg_intensity"], 3) if record["avg_intensity"] else 0.0
|
||||
}
|
||||
for record in results
|
||||
]
|
||||
|
||||
return formatted_results
|
||||
except Exception as e:
|
||||
logger.error(f"查询情绪词云失败: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_emotions_in_range(
|
||||
self,
|
||||
group_id: str,
|
||||
time_range: str = "30d"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取时间范围内的情绪数据
|
||||
|
||||
查询指定时间范围内的所有情绪数据,用于健康指数计算。
|
||||
|
||||
Args:
|
||||
group_id: 用户组ID(宿主ID)
|
||||
time_range: 时间范围(7d/30d/90d)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 情绪数据列表,每个包含:
|
||||
- emotion_type: 情绪类型
|
||||
- emotion_intensity: 情绪强度
|
||||
- created_at: 创建时间
|
||||
- statement_id: 陈述句ID
|
||||
"""
|
||||
# 解析时间范围
|
||||
days_map = {"7d": 7, "30d": 30, "90d": 90}
|
||||
days = days_map.get(time_range, 30)
|
||||
|
||||
# 计算起始日期(使用字符串比较,避免时区问题)
|
||||
start_date = (datetime.now() - timedelta(days=days)).isoformat()
|
||||
|
||||
# 优化的 Cypher 查询:使用字符串比较避免时区问题
|
||||
query = """
|
||||
MATCH (s:Statement)
|
||||
WHERE s.group_id = $group_id
|
||||
AND s.emotion_type IS NOT NULL
|
||||
AND s.created_at >= $start_date
|
||||
RETURN s.id as statement_id,
|
||||
s.emotion_type as emotion_type,
|
||||
s.emotion_intensity as emotion_intensity,
|
||||
s.created_at as created_at
|
||||
ORDER BY s.created_at ASC
|
||||
"""
|
||||
|
||||
try:
|
||||
results = await self.connector.execute_query(
|
||||
query,
|
||||
group_id=group_id,
|
||||
start_date=start_date
|
||||
)
|
||||
formatted_results = [
|
||||
{
|
||||
"statement_id": record["statement_id"],
|
||||
"emotion_type": record["emotion_type"],
|
||||
"emotion_intensity": record["emotion_intensity"],
|
||||
"created_at": record["created_at"].isoformat() if hasattr(record["created_at"], "isoformat") else str(record["created_at"])
|
||||
}
|
||||
for record in results
|
||||
]
|
||||
|
||||
return formatted_results
|
||||
except Exception as e:
|
||||
logger.error(f"查询时间范围情绪数据失败: {str(e)}", exc_info=True)
|
||||
return []
|
||||
227
api/app/repositories/neo4j/neo4j_update.py
Normal file
227
api/app/repositories/neo4j/neo4j_update.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from app.repositories import Neo4jConnector
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
async def update_neo4j_data(neo4j_dict_data, update_databases):
|
||||
"""
|
||||
Update Neo4j data based on query criteria and update parameters
|
||||
|
||||
Args:
|
||||
neo4j_dict_data: find
|
||||
update_databases: update
|
||||
"""
|
||||
try:
|
||||
# 构建WHERE条件
|
||||
where_conditions = []
|
||||
params = {}
|
||||
|
||||
for key, value in neo4j_dict_data.items():
|
||||
if value is not None:
|
||||
param_name = f"param_{key}"
|
||||
where_conditions.append(f"e.{key} = ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||||
|
||||
# 构建SET条件
|
||||
set_conditions = []
|
||||
for key, value in update_databases.items():
|
||||
if value is not None:
|
||||
param_name = f"update_{key}"
|
||||
set_conditions.append(f"e.{key} = ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
set_clause = ", ".join(set_conditions)
|
||||
|
||||
if not set_clause:
|
||||
print("警告: 没有需要更新的字段")
|
||||
return False
|
||||
|
||||
# 构建Cypher查询
|
||||
cypher_query = f"""
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE {where_clause}
|
||||
SET {set_clause}
|
||||
RETURN count(e) as updated_count, collect(e.name) as updated_names
|
||||
"""
|
||||
|
||||
print(f"\n执行Cypher查询: {cypher_query}")
|
||||
print(f"参数: {params}")
|
||||
|
||||
# 执行更新
|
||||
result = await neo4j_connector.execute_query(cypher_query, **params)
|
||||
|
||||
if result:
|
||||
updated_count = result[0].get('updated_count', 0)
|
||||
updated_names = result[0].get('updated_names', [])
|
||||
print(f"成功更新 {updated_count} 个节点")
|
||||
if updated_names:
|
||||
print(f"更新的实体名称: {updated_names}")
|
||||
return updated_count > 0
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"更新过程中出现错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def map_field_names(data_dict):
|
||||
mapped_dict = {}
|
||||
has_name_field = False
|
||||
|
||||
# 第一遍:检查是否有name相关字段
|
||||
for key, value in data_dict.items():
|
||||
if key in ['name', 'entity2.name', 'entity1.name']:
|
||||
has_name_field = True
|
||||
break
|
||||
|
||||
print(f"字段检查: has_name_field = {has_name_field}")
|
||||
|
||||
# 第二遍:根据规则映射和过滤字段
|
||||
for key, value in data_dict.items():
|
||||
if key == 'entity2.name' or key == 'entity2_name':
|
||||
# 将 entity2.name 映射为 name
|
||||
mapped_dict['name'] = value
|
||||
print(f"字段名映射: {key} -> name")
|
||||
elif key == 'entity1.name' or key == 'entity1_name':
|
||||
# 将 entity1.name 映射为 name
|
||||
mapped_dict['name'] = value
|
||||
print(f"字段名映射: {key} -> name")
|
||||
elif key == 'entity1.description':
|
||||
# 将 entity1.description 映射为 description
|
||||
mapped_dict['description'] = value
|
||||
print(f"字段名映射: {key} -> description")
|
||||
elif key == 'entity2.description':
|
||||
# 将 entity2.description 映射为 description
|
||||
mapped_dict['description'] = value
|
||||
print(f"字段名映射: {key} -> description")
|
||||
elif key == 'relationship_type':
|
||||
# 跳过relationship_type字段
|
||||
print(f"字段过滤: 跳过不需要的字段 '{key}'")
|
||||
continue
|
||||
elif key == 'entity1_name':
|
||||
if has_name_field:
|
||||
# 如果有name字段,跳过entity1_name
|
||||
print(f"字段过滤: 由于存在name字段,跳过 '{key}'")
|
||||
continue
|
||||
else:
|
||||
# 如果没有name字段,保留entity1_name
|
||||
mapped_dict[key] = value
|
||||
print(f"字段保留: {key}")
|
||||
elif key == 'entity2_name':
|
||||
if has_name_field:
|
||||
# 如果有name字段,跳过entity2_name
|
||||
print(f"字段过滤: 由于存在name字段,跳过 '{key}'")
|
||||
continue
|
||||
else:
|
||||
# 即使没有name字段,也不使用entity2_name(根据需求)
|
||||
print(f"字段过滤: 跳过不推荐的字段 '{key}'")
|
||||
continue
|
||||
elif '.' not in key:
|
||||
# 不包含点号的其他字段直接保留
|
||||
mapped_dict[key] = value
|
||||
else:
|
||||
# 其他包含点号的字段跳过并警告
|
||||
print(f"警告: 跳过不支持的嵌套字段 '{key}'")
|
||||
|
||||
print(f"字段映射结果: {mapped_dict}")
|
||||
return mapped_dict
|
||||
async def neo4j_data(solved_data):
|
||||
"""
|
||||
Process the resolved data and update the Neo4j database
|
||||
Args:
|
||||
Solved_data: Solution Data List
|
||||
Returns:
|
||||
Int: Number of successfully updated records
|
||||
"""
|
||||
success_count = 0
|
||||
|
||||
for i in solved_data:
|
||||
neo4j_dict_data = {}
|
||||
update_databases = {}
|
||||
results = i['results']
|
||||
for data in results:
|
||||
resolved = data.get('resolved')
|
||||
if not resolved:
|
||||
print("跳过:resolved为None")
|
||||
continue
|
||||
|
||||
try:
|
||||
change_list = resolved.get('change', [])
|
||||
except (AttributeError, TypeError):
|
||||
change_list = []
|
||||
|
||||
if change_list == []:
|
||||
print("跳过:change_list为空")
|
||||
continue
|
||||
|
||||
if change_list and len(change_list) > 0:
|
||||
change = change_list[0]
|
||||
print(f"change: {change}")
|
||||
field_data = change.get('field', [])
|
||||
print(f"field_data: {field_data}")
|
||||
print(f"field_data type: {type(field_data)}")
|
||||
|
||||
# 字段名映射和过滤函数
|
||||
|
||||
|
||||
# 处理field数据,可能是字典或列表
|
||||
if isinstance(field_data, dict):
|
||||
# 如果是字典,映射字段名后更新
|
||||
mapped_data = map_field_names(field_data)
|
||||
update_databases.update(mapped_data)
|
||||
elif isinstance(field_data, list):
|
||||
# 如果是列表,遍历每个字典并更新
|
||||
for field_item in field_data:
|
||||
if isinstance(field_item, dict):
|
||||
mapped_item = map_field_names(field_item)
|
||||
update_databases.update(mapped_item)
|
||||
else:
|
||||
print(f"警告: field_item不是字典: {field_item}")
|
||||
else:
|
||||
print(f"警告: field_data类型不支持: {type(field_data)}")
|
||||
|
||||
if 'entity1_name' in data:
|
||||
data['name'] = data.pop('entity1_name')
|
||||
if 'entity2_name' in data:
|
||||
data.pop('entity2_name', None)
|
||||
|
||||
resolved_memory = resolved.get('resolved_memory', {})
|
||||
|
||||
entity2 = None
|
||||
if isinstance(resolved_memory, dict):
|
||||
entity2 = resolved_memory.get('entity2')
|
||||
|
||||
if entity2 and isinstance(entity2, dict) and len(entity2) >= 5:
|
||||
stat_id = resolved.get('original_memory_id')
|
||||
# 安全地获取description
|
||||
statement_id = None
|
||||
if isinstance(resolved_memory, dict):
|
||||
statement_id = resolved_memory.get('statement_id')
|
||||
|
||||
# 只有当neo4j_dict_data中还没有statement_id时才使用original_memory_id
|
||||
if statement_id and 'id' not in neo4j_dict_data:
|
||||
neo4j_dict_data['id'] = stat_id
|
||||
neo4j_dict_data['statement_id'] = statement_id
|
||||
else:
|
||||
# 处理original_memory_id,它可能是字符串或字典
|
||||
try:
|
||||
for key, value in resolved_memory.items():
|
||||
if key == 'statement_id':
|
||||
neo4j_dict_data['statement_id'] = value
|
||||
if key == 'description':
|
||||
neo4j_dict_data['description'] = value
|
||||
except AttributeError:
|
||||
neo4j_dict_data=[]
|
||||
|
||||
print(neo4j_dict_data)
|
||||
print(update_databases)
|
||||
if neo4j_dict_data!=[]:
|
||||
await update_neo4j_data(neo4j_dict_data, update_databases)
|
||||
success_count += 1
|
||||
|
||||
return success_count
|
||||
|
||||
@@ -58,11 +58,22 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
||||
n['invalid_at'] = datetime.fromisoformat(n['invalid_at'])
|
||||
|
||||
# 处理temporal_info字段
|
||||
if isinstance(n.get('temporal_info'), dict):
|
||||
if isinstance(n.get('temporal_info'), str):
|
||||
# 从字符串转换为枚举值
|
||||
n['temporal_info'] = TemporalInfo(n['temporal_info'])
|
||||
elif isinstance(n.get('temporal_info'), dict):
|
||||
n['temporal_info'] = TemporalInfo(**n['temporal_info'])
|
||||
elif not n.get('temporal_info'):
|
||||
# 如果没有temporal_info,创建一个默认的
|
||||
n['temporal_info'] = TemporalInfo()
|
||||
n['temporal_info'] = TemporalInfo.STATIC
|
||||
|
||||
# 处理情绪字段 - 映射 Neo4j 节点属性到 StatementNode 模型
|
||||
# 处理空值情况,确保字段存在
|
||||
n['emotion_type'] = n.get('emotion_type')
|
||||
n['emotion_intensity'] = n.get('emotion_intensity')
|
||||
n['emotion_keywords'] = n.get('emotion_keywords', [])
|
||||
n['emotion_subject'] = n.get('emotion_subject')
|
||||
n['emotion_target'] = n.get('emotion_target')
|
||||
|
||||
return StatementNode(**n)
|
||||
|
||||
|
||||
124
api/app/repositories/prompt_optimizer_repository.py
Normal file
124
api/app/repositories/prompt_optimizer_repository.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.prompt_optimizer_model import (
|
||||
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
|
||||
)
|
||||
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class PromptOptimizerSessionRepository:
|
||||
"""Repository for managing prompt optimization sessions and session history."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
user_id: uuid.UUID
|
||||
) -> PromptOptimizerSession:
|
||||
"""
|
||||
Create a new prompt optimization session for a user and app.
|
||||
|
||||
Args:
|
||||
tenant_id (uuid.UUID): The unique identifier of the tenant.
|
||||
user_id (uuid.UUID): The unique identifier of the user.
|
||||
|
||||
Returns:
|
||||
PromptOptimizerSession: The newly created session object.
|
||||
"""
|
||||
db_logger.debug(f"Create prompt optimization session: tenant_id={tenant_id}, user_id={user_id}")
|
||||
try:
|
||||
session = PromptOptimizerSession(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
db_logger.debug(f"Prompt optimization session created: ID:{session.id}")
|
||||
return session
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_session_history(
|
||||
self,
|
||||
session_id: uuid.UUID,
|
||||
user_id: uuid.UUID
|
||||
) -> list[type[PromptOptimizerSessionHistory]]:
|
||||
"""
|
||||
Retrieve all message history of a specific prompt optimization session.
|
||||
|
||||
Args:
|
||||
session_id (uuid.UUID): The unique identifier of the session.
|
||||
user_id (uuid.UUID): The unique identifier of the user.
|
||||
|
||||
Returns:
|
||||
list[PromptOptimizerSessionHistory]: A list of session history records
|
||||
ordered by creation time ascending.
|
||||
"""
|
||||
db_logger.debug(f"Get prompt optimization session history: "
|
||||
f"user_id={user_id}, session_id={session_id}")
|
||||
|
||||
try:
|
||||
# First get the internal session ID from the session list table
|
||||
session = self.db.query(PromptOptimizerSession).filter(
|
||||
PromptOptimizerSession.id == session_id,
|
||||
PromptOptimizerSession.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not session:
|
||||
return []
|
||||
|
||||
history = self.db.query(PromptOptimizerSessionHistory).filter(
|
||||
PromptOptimizerSessionHistory.session_id == session.id,
|
||||
PromptOptimizerSessionHistory.user_id == user_id
|
||||
).order_by(PromptOptimizerSessionHistory.created_at.asc()).all()
|
||||
return history
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error retrieving prompt optimization session history: session_id={session_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def create_message(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
session_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
role: RoleType,
|
||||
content: str,
|
||||
) -> PromptOptimizerSessionHistory:
|
||||
"""
|
||||
Create a new message in the session history.
|
||||
|
||||
This method is a placeholder for future implementation.
|
||||
"""
|
||||
try:
|
||||
# Get the session to ensure it exists and belongs to the user
|
||||
session = self.db.query(PromptOptimizerSession).filter(
|
||||
PromptOptimizerSession.id == session_id,
|
||||
PromptOptimizerSession.user_id == user_id,
|
||||
PromptOptimizerSession.tenant_id == tenant_id
|
||||
).first()
|
||||
|
||||
if not session:
|
||||
db_logger.error(f"Session {session_id} not found for user {user_id}")
|
||||
raise ValueError(f"Session {session_id} not found for user {user_id}")
|
||||
|
||||
message = PromptOptimizerSessionHistory(
|
||||
tenant_id=tenant_id,
|
||||
session_id=session.id,
|
||||
user_id=user_id,
|
||||
role=role.value,
|
||||
content=content,
|
||||
)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
return message
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
|
||||
raise
|
||||
Reference in New Issue
Block a user