config_config替换成memory_config

This commit is contained in:
lixinyue
2026-01-22 18:43:22 +08:00
parent f3f9211c9c
commit 8db4f914d8
21 changed files with 158 additions and 201 deletions

View File

@@ -12,6 +12,7 @@ from fastapi import APIRouter, Depends, Query, HTTPException, status
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from uuid import UUID
from app.core.response_utils import success from app.core.response_utils import success
from app.dependencies import get_current_user from app.dependencies import get_current_user
@@ -32,11 +33,11 @@ router = APIRouter(
class EmotionConfigQuery(BaseModel): class EmotionConfigQuery(BaseModel):
"""情绪配置查询请求模型""" """情绪配置查询请求模型"""
config_id: int = Field(..., description="配置ID") config_id: UUID = Field(..., description="配置ID")
class EmotionConfigUpdate(BaseModel): class EmotionConfigUpdate(BaseModel):
"""情绪配置更新请求模型""" """情绪配置更新请求模型"""
config_id: int = Field(..., description="配置ID") config_id: UUID = Field(..., description="配置ID")
emotion_enabled: bool = Field(..., description="是否启用情绪提取") emotion_enabled: bool = Field(..., description="是否启用情绪提取")
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID") emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词") emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
@@ -45,7 +46,7 @@ class EmotionConfigUpdate(BaseModel):
@router.get("/read_config", response_model=ApiResponse) @router.get("/read_config", response_model=ApiResponse)
def get_emotion_config( def get_emotion_config(
config_id: int = Query(..., description="配置ID"), config_id: UUID = Query(..., description="配置ID"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):

View File

@@ -11,6 +11,7 @@
""" """
from typing import Optional from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -128,7 +129,7 @@ async def trigger_forgetting_cycle(
@router.get("/read_config", response_model=ApiResponse) @router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config( async def read_forgetting_config(
config_id: int, config_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import time import time
import uuid import uuid
from uuid import UUID
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.reflection_engine.self_reflexion import ( from app.core.memory.storage_services.reflection_engine.self_reflexion import (
@@ -156,7 +157,7 @@ async def start_workspace_reflection(
@router.get("/reflection/configs") @router.get("/reflection/configs")
async def start_reflection_configs( async def start_reflection_configs(
config_id: int, config_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
@@ -191,7 +192,7 @@ async def start_reflection_configs(
@router.get("/reflection/run") @router.get("/reflection/run")
async def reflection_run( async def reflection_run(
config_id: int, config_id: UUID,
language_type: str = Header(default="zh", alias="X-Language-Type"), language_type: str = Header(default="zh", alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),

View File

@@ -1,5 +1,6 @@
import os import os
from typing import Optional from typing import Optional
from uuid import UUID
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
@@ -160,7 +161,7 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config( def delete_config(
config_id: str, config_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
@@ -232,7 +233,7 @@ def update_config_extracted(
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 @router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted( def read_config_extracted(
config_id: str, config_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:

View File

@@ -11,6 +11,7 @@ Functions:
import logging import logging
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.repositories.memory_config_repository import MemoryConfigRepository from app.repositories.memory_config_repository import MemoryConfigRepository
@@ -61,7 +62,7 @@ def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
def load_actr_config_from_db( def load_actr_config_from_db(
db: Session, db: Session,
config_id: Optional[int] = None config_id: Optional[UUID] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
从数据库加载 ACT-R 配置参数 从数据库加载 ACT-R 配置参数
@@ -150,7 +151,7 @@ def load_actr_config_from_db(
def create_actr_calculator_from_config( def create_actr_calculator_from_config(
db: Session, db: Session,
config_id: Optional[int] = None config_id: Optional[UUID] = None
) -> ACTRCalculator: ) -> ACTRCalculator:
""" """
从数据库配置创建 ACTRCalculator 实例 从数据库配置创建 ACTRCalculator 实例
@@ -168,11 +169,6 @@ def create_actr_calculator_from_config(
ValueError: 如果指定的 config_id 不存在 ValueError: 如果指定的 config_id 不存在
Examples: Examples:
>>> from sqlalchemy.orm import Session
>>> db = Session()
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
>>> # 使用计算器
>>> activation = calculator.calculate_memory_activation(...)
""" """
# 加载配置 # 加载配置
config = load_actr_config_from_db(db, config_id) config = load_actr_config_from_db(db, config_id)

View File

@@ -16,6 +16,7 @@ Classes:
import logging import logging
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from uuid import UUID
from datetime import datetime from datetime import datetime
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
@@ -69,7 +70,7 @@ class ForgettingScheduler:
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
max_merge_batch_size: int = 100, max_merge_batch_size: int = 100,
min_days_since_access: int = 30, min_days_since_access: int = 30,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
db = None db = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """

View File

@@ -13,6 +13,7 @@ Classes:
import logging import logging
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from uuid import UUID
from datetime import datetime, timedelta from datetime import datetime, timedelta
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -176,7 +177,7 @@ class ForgettingStrategy:
self, self,
statement_node: Dict[str, Any], statement_node: Dict[str, Any],
entity_node: Dict[str, Any], entity_node: Dict[str, Any],
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
db = None db = None
) -> str: ) -> str:
""" """
@@ -462,7 +463,7 @@ class ForgettingStrategy:
statement_text: str, statement_text: str,
entity_name: str, entity_name: str,
entity_type: str, entity_type: str,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
db = None db = None
) -> str: ) -> str:
""" """
@@ -527,7 +528,7 @@ class ForgettingStrategy:
statement_text, entity_name, entity_type statement_text, entity_name, entity_type
) )
async def _get_llm_client(self, db, config_id: int): async def _get_llm_client(self, db, config_id: UUID):
""" """
从数据库获取 LLM 客户端 从数据库获取 LLM 客户端

View File

@@ -26,7 +26,7 @@ logger = get_config_logger()
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str, def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]: config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
"""Parse model ID from string or UUID.""" """Parse model ID from string or UUID."""
if model_id is None: if model_id is None:
return None return None
@@ -59,7 +59,7 @@ def validate_model_exists_and_active(
model_type: str, model_type: str,
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> tuple[str, bool]: ) -> tuple[str, bool]:
"""Validate that a model exists and is active. """Validate that a model exists and is active.
@@ -166,7 +166,7 @@ def validate_and_resolve_model_id(
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
required: bool = False, required: bool = False,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> tuple[Optional[UUID], Optional[str]]: ) -> tuple[Optional[UUID], Optional[str]]:
"""Validate and resolve a model ID, checking existence and active status. """Validate and resolve a model ID, checking existence and active status.
@@ -204,7 +204,7 @@ def validate_and_resolve_model_id(
def validate_embedding_model( def validate_embedding_model(
config_id: int, config_id: UUID,
embedding_id: Union[str, UUID, None], embedding_id: Union[str, UUID, None],
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
@@ -256,7 +256,7 @@ def validate_embedding_model(
def validate_llm_model( def validate_llm_model(
config_id: int, config_id: UUID,
llm_id: Union[str, UUID, None], llm_id: Union[str, UUID, None],
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,

View File

@@ -1,4 +1,5 @@
import uuid import uuid
from uuid import UUID
from pydantic import Field from pydantic import Field
from typing import Literal from typing import Literal
@@ -11,7 +12,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
... ...
) )
config_id: int = Field( config_id: UUID = Field(
... ...
) )
@@ -26,6 +27,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
... ...
) )
config_id: int = Field( config_id: UUID = Field(
... ...
) )

View File

@@ -9,7 +9,7 @@ class MemoryConfig(Base):
__tablename__ = "memory_config" __tablename__ = "memory_config"
# 主键 # 主键
config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID") config_id = Column(UUID(as_uuid=True), primary_key=True, comment="配置ID")
# 基本信息 # 基本信息
config_name = Column(String, nullable=False, comment="配置名称") config_name = Column(String, nullable=False, comment="配置名称")

View File

@@ -9,6 +9,7 @@ Classes:
""" """
import uuid import uuid
from uuid import UUID
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_config_logger, get_db_logger from app.core.logging_config import get_config_logger, get_db_logger
@@ -107,7 +108,7 @@ class MemoryConfigRepository:
@staticmethod @staticmethod
def update_reflection_config( def update_reflection_config(
db: Session, db: Session,
config_id: int, config_id: uuid.UUID,
enable_self_reflexion: bool, enable_self_reflexion: bool,
iteration_period: str, iteration_period: str,
reflexion_range: str, reflexion_range: str,
@@ -151,7 +152,7 @@ class MemoryConfigRepository:
return memory_config_obj return memory_config_obj
@staticmethod @staticmethod
def query_reflection_config_by_id(db: Session, config_id: int) -> MemoryConfig: def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数) """构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args: Args:
@@ -222,6 +223,7 @@ class MemoryConfigRepository:
try: try:
db_config = MemoryConfig( db_config = MemoryConfig(
config_id=uuid.uuid4(),
config_name=params.config_name, config_name=params.config_name,
config_desc=params.config_desc, config_desc=params.config_desc,
workspace_id=params.workspace_id, workspace_id=params.workspace_id,
@@ -408,7 +410,7 @@ class MemoryConfigRepository:
raise raise
@staticmethod @staticmethod
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]: def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]:
"""获取萃取配置,通过主键查询某条配置 """获取萃取配置,通过主键查询某条配置
Args: Args:
@@ -457,7 +459,7 @@ class MemoryConfigRepository:
raise raise
@staticmethod @staticmethod
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]: def get_forget_config(db: Session, config_id: UUID) -> Optional[Dict]:
"""获取遗忘配置,通过主键查询某条配置 """获取遗忘配置,通过主键查询某条配置
Args: Args:
@@ -489,7 +491,7 @@ class MemoryConfigRepository:
raise raise
@staticmethod @staticmethod
def get_by_id(db: Session, config_id: int) -> Optional[MemoryConfig]: def get_by_id(db: Session, config_id: uuid.UUID) -> Optional[MemoryConfig]:
"""根据ID获取记忆配置 """根据ID获取记忆配置
Args: Args:
@@ -513,7 +515,7 @@ class MemoryConfigRepository:
db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}") db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
raise raise
@staticmethod @staticmethod
def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]: def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]:
"""Get memory config and its associated workspace information """Get memory config and its associated workspace information
Args: Args:
@@ -664,7 +666,7 @@ class MemoryConfigRepository:
raise raise
@staticmethod @staticmethod
def delete(db: Session, config_id: int) -> bool: def delete(db: Session, config_id: uuid.UUID) -> bool:
"""删除记忆配置 """删除记忆配置
Args: Args:

View File

@@ -1,6 +1,7 @@
"""情绪分析相关的请求和响应模型""" """情绪分析相关的请求和响应模型"""
from typing import Optional from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class EmotionTagsRequest(BaseModel): class EmotionTagsRequest(BaseModel):
@@ -30,7 +31,7 @@ class EmotionHealthRequest(BaseModel):
class EmotionSuggestionsRequest(BaseModel): class EmotionSuggestionsRequest(BaseModel):
"""获取个性化情绪建议请求""" """获取个性化情绪建议请求"""
end_user_id: str = Field(..., description="组ID") end_user_id: str = Field(..., description="组ID")
config_id: Optional[int] = Field(None, description="配置ID用于指定LLM模型") config_id: Optional[UUID] = Field(None, description="配置ID用于指定LLM模型")
class EmotionGenerateSuggestionsRequest(BaseModel): class EmotionGenerateSuggestionsRequest(BaseModel):

View File

@@ -15,3 +15,6 @@ class Write_UserInput(BaseModel):
messages: list[dict] messages: list[dict]
end_user_id: str end_user_id: str
config_id: Optional[str] = None config_id: Optional[str] = None
class End_User_Information(BaseModel):
end_user_name: str # 这是要更新的用户名
id: str # 宿主ID用于匹配条件

View File

@@ -35,7 +35,7 @@ class ConfigurationError(Exception):
def __init__( def __init__(
self, self,
message: str, message: str,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None,
context: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None,
): ):
@@ -72,7 +72,7 @@ class WorkspaceNotFoundError(ConfigurationError):
def __init__( def __init__(
self, self,
workspace_id: UUID, workspace_id: UUID,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
message: Optional[str] = None, message: Optional[str] = None,
): ):
if message is None: if message is None:
@@ -89,7 +89,7 @@ class ModelNotFoundError(ConfigurationError):
self, self,
model_id: Union[str, UUID], model_id: Union[str, UUID],
model_type: str, model_type: str,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None,
message: Optional[str] = None, message: Optional[str] = None,
): ):
@@ -112,7 +112,7 @@ class ModelInactiveError(ConfigurationError):
model_id: Union[str, UUID], model_id: Union[str, UUID],
model_name: str, model_name: str,
model_type: str, model_type: str,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None,
message: Optional[str] = None, message: Optional[str] = None,
): ):
@@ -136,7 +136,7 @@ class InvalidConfigError(ConfigurationError):
message: str, message: str,
field_name: Optional[str] = None, field_name: Optional[str] = None,
invalid_value: Optional[Any] = None, invalid_value: Optional[Any] = None,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None,
): ):
context = {} context = {}
@@ -155,7 +155,7 @@ class InvalidConfigError(ConfigurationError):
class MemoryConfigValidation(BaseModel): class MemoryConfigValidation(BaseModel):
"""Pydantic model for validating memory configuration data from database.""" """Pydantic model for validating memory configuration data from database."""
config_id: int = Field(..., gt=0, description="Configuration ID must be positive") config_id: UUID = Field(..., description="Configuration ID (UUID)")
config_name: str = Field(..., min_length=1, max_length=255) config_name: str = Field(..., min_length=1, max_length=255)
workspace_id: UUID = Field(..., description="Workspace UUID") workspace_id: UUID = Field(..., description="Workspace UUID")
workspace_name: str = Field(..., min_length=1, max_length=255) workspace_name: str = Field(..., min_length=1, max_length=255)
@@ -275,7 +275,7 @@ class ModelValidation(BaseModel):
def validate_memory_config_data( def validate_memory_config_data(
config_data: Dict[str, Any], config_id: Optional[int] = None config_data: Dict[str, Any], config_id: Optional[UUID] = None
) -> MemoryConfigValidation: ) -> MemoryConfigValidation:
"""Validate memory configuration data using Pydantic model.""" """Validate memory configuration data using Pydantic model."""
try: try:
@@ -302,7 +302,7 @@ def validate_memory_config_data(
def validate_workspace_data( def validate_workspace_data(
workspace_data: Dict[str, Any], config_id: Optional[int] = None workspace_data: Dict[str, Any], config_id: Optional[UUID] = None
) -> WorkspaceValidation: ) -> WorkspaceValidation:
"""Validate workspace data using Pydantic model.""" """Validate workspace data using Pydantic model."""
try: try:
@@ -331,7 +331,7 @@ def validate_workspace_data(
def validate_model_data( def validate_model_data(
model_data: Dict[str, Any], config_id: Optional[int] = None model_data: Dict[str, Any], config_id: Optional[UUID] = None
) -> ModelValidation: ) -> ModelValidation:
"""Validate model data using Pydantic model.""" """Validate model data using Pydantic model."""
try: try:
@@ -364,7 +364,7 @@ def validate_model_data(
class MemoryConfig: class MemoryConfig:
"""Immutable memory configuration loaded from database.""" """Immutable memory configuration loaded from database."""
config_id: int config_id: UUID
config_name: str config_name: str
workspace_id: UUID workspace_id: UUID
workspace_name: str workspace_name: str

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
from uuid import UUID
from enum import Enum from enum import Enum
@@ -9,7 +10,7 @@ class OptimizationStrategy(str, Enum):
ACCURACY_FIRST = "accuracy_first" ACCURACY_FIRST = "accuracy_first"
BALANCED = "balanced" BALANCED = "balanced"
class Memory_Reflection(BaseModel): class Memory_Reflection(BaseModel):
config_id: Optional[int] = None config_id: Optional[UUID] = None
reflection_enabled: bool reflection_enabled: bool
reflection_period_in_hours: str reflection_period_in_hours: str
reflexion_range: Optional[str] = "partial" reflexion_range: Optional[str] = "partial"

View File

@@ -159,7 +159,7 @@ class ReflexionResultSchema(BaseModel):
# Composite key identifying a config row # Composite key identifying a config row
class ConfigKey(BaseModel): # 配置参数键模型 class ConfigKey(BaseModel): # 配置参数键模型
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: int = Field("config_id", description="配置唯一标识(字符串") config_id: uuid.UUID = Field("config_id", description="配置唯一标识(UUID")
user_id: str = Field("user_id", description="用户标识(字符串)") user_id: str = Field("user_id", description="用户标识(字符串)")
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
@@ -250,17 +250,17 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
# config_name: str = Field("配置名称", description="配置名称(字符串)") # config_name: str = Field("配置名称", description="配置名称(字符串)")
config_id: int = Field("配置ID", description="配置ID字符串") config_id: uuid.UUID = Field("配置ID", description="配置IDUUID")
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Optional[int] = None config_id: Optional[uuid.UUID] = None
config_name: str = Field("配置名称", description="配置名称(字符串)") config_name: str = Field("配置名称", description="配置名称(字符串)")
config_desc: str = Field("配置描述", description="配置描述(字符串)") config_desc: str = Field("配置描述", description="配置描述(字符串)")
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Optional[int] = None config_id: Optional[uuid.UUID] = None
llm_id: Optional[str] = Field(None, description="LLM模型配置ID") llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
@@ -327,14 +327,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
# 遗忘引擎配置参数更新模型 # 遗忘引擎配置参数更新模型
config_id: Optional[int] = None config_id: Optional[uuid.UUID] = None
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5") lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5")
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5") lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5")
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0") offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0")
class ConfigPilotRun(BaseModel): # 试运行触发请求模型 class ConfigPilotRun(BaseModel): # 试运行触发请求模型
config_id: int = Field(..., description="配置ID唯一") config_id: uuid.UUID = Field(..., description="配置ID唯一")
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -342,7 +342,7 @@ class ConfigPilotRun(BaseModel): # 试运行触发请求模型
class ConfigFilter(BaseModel): # 查询配置参数时使用的模型 class ConfigFilter(BaseModel): # 查询配置参数时使用的模型
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: Optional[int] = None config_id: Optional[uuid.UUID] = None
user_id: Optional[str] = None user_id: Optional[str] = None
apply_id: Optional[str] = None apply_id: Optional[str] = None
@@ -418,7 +418,7 @@ class ForgettingConfigResponse(BaseModel):
"""遗忘引擎配置响应模型""" """遗忘引擎配置响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: int = Field(..., description="配置ID") config_id: uuid.UUID = Field(..., description="配置ID")
decay_constant: float = Field(..., description="衰减常数 d") decay_constant: float = Field(..., description="衰减常数 d")
lambda_time: float = Field(..., description="时间衰减参数") lambda_time: float = Field(..., description="时间衰减参数")
lambda_mem: float = Field(..., description="记忆衰减参数") lambda_mem: float = Field(..., description="记忆衰减参数")
@@ -436,7 +436,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
"""遗忘引擎配置更新请求模型""" """遗忘引擎配置更新请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: int = Field(..., description="配置ID") config_id: uuid.UUID = Field(..., description="配置ID")
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d") decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数") lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数") lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
@@ -511,7 +511,7 @@ class ForgettingCurveRequest(BaseModel):
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数0-1") importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数0-1")
days: int = Field(60, ge=1, le=365, description="模拟天数默认60天") days: int = Field(60, ge=1, le=365, description="模拟天数默认60天")
config_id: Optional[int] = Field(None, description="配置ID可选如果为None则使用默认配置") config_id: Optional[uuid.UUID] = Field(None, description="配置ID可选如果为None则使用默认配置")
class ForgettingCurveResponse(BaseModel): class ForgettingCurveResponse(BaseModel):

View File

@@ -8,6 +8,8 @@ Classes:
""" """
from typing import Dict, Any from typing import Dict, Any
from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.models.memory_config_model import MemoryConfig from app.models.memory_config_model import MemoryConfig
@@ -37,7 +39,7 @@ class EmotionConfigService:
self.db = db self.db = db
logger.info("情绪配置服务初始化完成") logger.info("情绪配置服务初始化完成")
def get_emotion_config(self, config_id: int) -> Dict[str, Any]: def get_emotion_config(self, config_id: UUID) -> Dict[str, Any]:
"""获取情绪引擎配置 """获取情绪引擎配置
查询指定配置ID的情绪相关配置字段。 查询指定配置ID的情绪相关配置字段。
@@ -144,7 +146,7 @@ class EmotionConfigService:
def update_emotion_config( def update_emotion_config(
self, self,
config_id: int, config_id: UUID,
config_data: Dict[str, Any] config_data: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""更新情绪引擎配置 """更新情绪引擎配置

View File

@@ -9,6 +9,7 @@ import os
import re import re
import time import time
import uuid import uuid
from uuid import UUID
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
import redis import redis
@@ -266,7 +267,7 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources") logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic # LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
""" """
Process write operation with config_id Process write operation with config_id
@@ -319,85 +320,52 @@ class MemoryAgentService:
raise ValueError(error_msg) raise ValueError(error_msg)
async with make_write_graph() as graph: try:
config = {"configurable": {"thread_id": end_user_id}} if storage_type == "rag":
# Convert structured messages to LangChain messages # For RAG storage, convert messages to single string
langchain_messages = [] message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
for msg in messages: result = await write_rag(end_user_id, message_text, user_rag_memory_id)
if msg['role'] == 'user': return result
langchain_messages.append(HumanMessage(content=msg['content'])) else:
elif msg['role'] == 'assistant': async with make_write_graph() as graph:
langchain_messages.append(AIMessage(content=msg['content'])) config = {"configurable": {"thread_id": end_user_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = { initial_state = {
"messages": langchain_messages, "messages": langchain_messages,
"end_user_id": end_user_id, "end_user_id": end_user_id,
"memory_config": memory_config "memory_config": memory_config
} }
# 获取节点更新信息 # 获取节点更新信息
async for update_event in graph.astream( async for update_event in graph.astream(
initial_state, initial_state,
stream_mode="updates", stream_mode="updates",
config=config config=config
): ):
for node_name, node_data in update_event.items(): for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name: if 'save_neo4j' == node_name:
massages = node_data massages = node_data
print(massages) massagesstatus = massages.get('write_result')['status']
massagesstatus = massages.get('write_result')['status'] contents = massages.get('write_result')
contents = massages.get('write_result') # Convert messages back to string for logging
# Convert messages back to string for logging message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) except Exception as e:
# Ensure proper error handling and logging
# try: error_msg = f"Write operation failed: {str(e)}"
# if storage_type == "rag": logger.error(error_msg)
# # For RAG storage, convert messages to single string if audit_logger:
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) duration = time.time() - start_time
# result = await write_rag(end_user_id, message_text, user_rag_memory_id) audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
# return result raise ValueError(error_msg)
# else:
# async with make_write_graph() as graph:
# config = {"configurable": {"thread_id": end_user_id}}
# # Convert structured messages to LangChain messages
# langchain_messages = []
# for msg in messages:
# if msg['role'] == 'user':
# langchain_messages.append(HumanMessage(content=msg['content']))
# elif msg['role'] == 'assistant':
# langchain_messages.append(AIMessage(content=msg['content']))
#
# # 初始状态 - 包含所有必要字段
# initial_state = {
# "messages": langchain_messages,
# "end_user_id": end_user_id,
# "memory_config": memory_config
# }
#
# # 获取节点更新信息
# async for update_event in graph.astream(
# initial_state,
# stream_mode="updates",
# config=config
# ):
# for node_name, node_data in update_event.items():
# if 'save_neo4j' == node_name:
# massages = node_data
# massagesstatus = massages.get('write_result')['status']
# contents = massages.get('write_result')
# # Convert messages back to string for logging
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
# return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
# except Exception as e:
# # Ensure proper error handling and logging
# error_msg = f"Write operation failed: {str(e)}"
# logger.error(error_msg)
# if audit_logger:
# duration = time.time() - start_time
# audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
# raise ValueError(error_msg)
@@ -408,7 +376,7 @@ class MemoryAgentService:
message: str, message: str,
history: List[Dict], history: List[Dict],
search_switch: str, search_switch: str,
config_id: Optional[str], config_id: Optional[UUID],
db: Session, db: Session,
storage_type: str, storage_type: str,
user_rag_memory_id: str) -> Dict: user_rag_memory_id: str) -> Dict:
@@ -685,7 +653,7 @@ class MemoryAgentService:
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}") logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages return user_input.messages
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: async def classify_message_type(self, message: str, config_id: UUID, db: Session) -> Dict:
""" """
Determine the type of user message (read or write) Determine the type of user message (read or write)
Updated to eliminate global variables in favor of explicit parameters. Updated to eliminate global variables in favor of explicit parameters.
@@ -716,7 +684,7 @@ class MemoryAgentService:
retrieve_info: str, retrieve_info: str,
history: List[Dict], history: List[Dict],
query: str, query: str,
config_id: str, config_id: UUID,
db: Session db: Session
) -> str: ) -> str:
""" """

View File

@@ -23,53 +23,12 @@ from app.schemas.memory_config_schema import (
ModelNotFoundError, ModelNotFoundError,
) )
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from uuid import UUID
logger = get_logger(__name__) logger = get_logger(__name__)
config_logger = get_config_logger() config_logger = get_config_logger()
def _validate_config_id(config_id):
"""Validate configuration ID format."""
if config_id is None:
raise InvalidConfigError(
"Configuration ID cannot be None",
field_name="config_id",
invalid_value=config_id,
)
if isinstance(config_id, int):
if config_id <= 0:
raise InvalidConfigError(
f"Configuration ID must be positive: {config_id}",
field_name="config_id",
invalid_value=config_id,
)
return config_id
if isinstance(config_id, str):
try:
parsed_id = int(config_id.strip())
if parsed_id <= 0:
raise InvalidConfigError(
f"Configuration ID must be positive: {parsed_id}",
field_name="config_id",
invalid_value=config_id,
)
return parsed_id
except ValueError:
raise InvalidConfigError(
f"Invalid configuration ID format: '{config_id}'",
field_name="config_id",
invalid_value=config_id,
)
raise InvalidConfigError(
f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
field_name="config_id",
invalid_value=config_id,
)
class MemoryConfigService: class MemoryConfigService:
""" """
Centralized service for memory configuration loading and validation. Centralized service for memory configuration loading and validation.
@@ -93,14 +52,14 @@ class MemoryConfigService:
def load_memory_config( def load_memory_config(
self, self,
config_id: int, config_id: UUID,
service_name: str = "MemoryConfigService", service_name: str = "MemoryConfigService",
) -> MemoryConfig: ) -> MemoryConfig:
""" """
Load memory configuration from database by config_id. Load memory configuration from database by config_id.
Args: Args:
config_id: Configuration ID from database config_id: Configuration ID (UUID) from database
service_name: Name of the calling service (for logging purposes) service_name: Name of the calling service (for logging purposes)
Returns: Returns:
@@ -116,18 +75,34 @@ class MemoryConfigService:
extra={ extra={
"operation": "load_memory_config", "operation": "load_memory_config",
"service": service_name, "service": service_name,
"config_id": config_id, "config_id": str(config_id),
}, },
) )
logger.info(f"Loading memory configuration from database: config_id={config_id}") logger.info(f"Loading memory configuration from database: config_id={config_id}")
try: try:
validated_config_id = _validate_config_id(config_id) # Validate config_id is UUID
if not isinstance(config_id, UUID):
if isinstance(config_id, str):
try:
config_id = UUID(config_id)
except ValueError:
raise InvalidConfigError(
f"Invalid UUID format for config_id: {config_id}",
field_name="config_id",
invalid_value=config_id,
)
else:
raise InvalidConfigError(
f"config_id must be UUID or valid UUID string, got {type(config_id).__name__}",
field_name="config_id",
invalid_value=config_id,
)
# Step 1: Get config and workspace # Step 1: Get config and workspace
db_query_start = time.time() db_query_start = time.time()
result = MemoryConfigRepository.get_config_with_workspace(self.db, validated_config_id) result = MemoryConfigRepository.get_config_with_workspace(self.db, config_id)
db_query_time = time.time() - db_query_start db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s") logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result: if not result:
@@ -136,14 +111,14 @@ class MemoryConfigService:
"Configuration not found in database", "Configuration not found in database",
extra={ extra={
"operation": "load_memory_config", "operation": "load_memory_config",
"config_id": validated_config_id, "config_id": str(config_id),
"load_result": "not_found", "load_result": "not_found",
"elapsed_ms": elapsed_ms, "elapsed_ms": elapsed_ms,
"service": service_name, "service": service_name,
}, },
) )
raise ConfigurationError( raise ConfigurationError(
f"Configuration {validated_config_id} not found in database" f"Configuration {config_id} not found in database"
) )
memory_config, workspace = result memory_config, workspace = result
@@ -151,7 +126,7 @@ class MemoryConfigService:
# Step 2: Validate embedding model (returns both UUID and name) # Step 2: Validate embedding model (returns both UUID and name)
embed_start = time.time() embed_start = time.time()
embedding_uuid, embedding_name = validate_embedding_model( embedding_uuid, embedding_name = validate_embedding_model(
validated_config_id, config_id,
memory_config.embedding_id, memory_config.embedding_id,
self.db, self.db,
workspace.tenant_id, workspace.tenant_id,
@@ -168,7 +143,7 @@ class MemoryConfigService:
self.db, self.db,
workspace.tenant_id, workspace.tenant_id,
required=True, required=True,
config_id=validated_config_id, config_id=config_id,
workspace_id=workspace.id, workspace_id=workspace.id,
) )
llm_time = time.time() - llm_start llm_time = time.time() - llm_start
@@ -185,7 +160,7 @@ class MemoryConfigService:
self.db, self.db,
workspace.tenant_id, workspace.tenant_id,
required=False, required=False,
config_id=validated_config_id, config_id=config_id,
workspace_id=workspace.id, workspace_id=workspace.id,
) )
rerank_time = time.time() - rerank_start rerank_time = time.time() - rerank_start
@@ -243,7 +218,7 @@ class MemoryConfigService:
extra={ extra={
"operation": "load_memory_config", "operation": "load_memory_config",
"service": service_name, "service": service_name,
"config_id": validated_config_id, "config_id": str(config_id),
"config_name": config.config_name, "config_name": config.config_name,
"workspace_id": str(config.workspace_id), "workspace_id": str(config.workspace_id),
"load_result": "success", "load_result": "success",

View File

@@ -12,6 +12,7 @@
from typing import Optional, Dict, Any, Tuple from typing import Optional, Dict, Any, Tuple
from datetime import datetime, timezone from datetime import datetime, timezone
from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -87,7 +88,7 @@ class MemoryForgetService:
async def _get_forgetting_components( async def _get_forgetting_components(
self, self,
db: Session, db: Session,
config_id: Optional[int] = None config_id: Optional[UUID] = None
) -> Tuple[ACTRCalculator, ForgettingStrategy, ForgettingScheduler, Dict[str, Any]]: ) -> Tuple[ACTRCalculator, ForgettingStrategy, ForgettingScheduler, Dict[str, Any]]:
""" """
获取遗忘引擎组件(计算器、策略、调度器) 获取遗忘引擎组件(计算器、策略、调度器)
@@ -294,7 +295,7 @@ class MemoryForgetService:
end_user_id: str, end_user_id: str,
max_merge_batch_size: Optional[int] = None, max_merge_batch_size: Optional[int] = None,
min_days_since_access: Optional[int] = None, min_days_since_access: Optional[int] = None,
config_id: Optional[int] = None config_id: Optional[UUID] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
手动触发遗忘周期 手动触发遗忘周期
@@ -389,7 +390,7 @@ class MemoryForgetService:
def read_forgetting_config( def read_forgetting_config(
self, self,
db: Session, db: Session,
config_id: int config_id: UUID
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取遗忘引擎配置 获取遗忘引擎配置
@@ -416,7 +417,7 @@ class MemoryForgetService:
def update_forgetting_config( def update_forgetting_config(
self, self,
db: Session, db: Session,
config_id: int, config_id: UUID,
update_fields: Dict[str, Any] update_fields: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -466,7 +467,7 @@ class MemoryForgetService:
self, self,
db: Session, db: Session,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
config_id: Optional[int] = None config_id: Optional[UUID] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取遗忘引擎统计信息 获取遗忘引擎统计信息
@@ -677,7 +678,7 @@ class MemoryForgetService:
db: Session, db: Session,
importance_score: float, importance_score: float,
days: int, days: int,
config_id: Optional[int] = None config_id: Optional[UUID] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取遗忘曲线数据 获取遗忘曲线数据

View File

@@ -4,6 +4,7 @@ import os
import re import re
import time import time
import uuid import uuid
from uuid import UUID
from datetime import datetime, timezone from datetime import datetime, timezone
from math import ceil from math import ceil
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@@ -382,7 +383,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
@celery_app.task(name="app.core.memory.agent.read_message", bind=True) @celery_app.task(name="app.core.memory.agent.read_message", bind=True)
def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: uuid.UUID, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]:
"""Celery task to process a read message via MemoryAgentService. """Celery task to process a read message via MemoryAgentService.
@@ -472,7 +473,7 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
@celery_app.task(name="app.core.memory.agent.write_message", bind=True) @celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, end_user_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: def write_message_task(self, end_user_id: str, message: str, config_id: uuid.UUID, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService. """Celery task to process a write message via MemoryAgentService.
Args: Args:
@@ -1084,7 +1085,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
@celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True) @celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True)
def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str, Any]: def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
"""定时任务:运行遗忘周期 """定时任务:运行遗忘周期
定期执行遗忘周期,识别并融合低激活值的知识节点。 定期执行遗忘周期,识别并融合低激活值的知识节点。