Merge pull request #225 from SuanmoSuanyangTechnology/fix/memory_bug_fix

Fix/memory bug fix
This commit is contained in:
Mark
2026-01-28 16:10:58 +08:00
committed by GitHub
7 changed files with 67 additions and 18 deletions

View File

@@ -21,6 +21,7 @@ from app.schemas.response_schema import ApiResponse
from app.services.emotion_config_service import EmotionConfigService
from app.core.logging_config import get_api_logger
from app.db import get_db
from app.utils.config_utils import resolve_config_id
# 获取API专用日志器
api_logger = get_api_logger()
@@ -46,7 +47,7 @@ class EmotionConfigUpdate(BaseModel):
@router.get("/read_config", response_model=ApiResponse)
def get_emotion_config(
config_id: UUID = Query(..., description="配置ID"),
config_id: UUID|int = Query(..., description="配置ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
@@ -79,7 +80,7 @@ def get_emotion_config(
f"用户 {current_user.username} 请求获取情绪配置",
extra={"config_id": config_id}
)
config_id=resolve_config_id(config_id, db)
# 初始化服务
config_service = EmotionConfigService(db)

View File

@@ -34,7 +34,7 @@ from app.schemas.memory_storage_schema import (
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_forget_service import MemoryForgetService
from app.utils.config_utils import resolve_config_id
# 获取API专用日志器
api_logger = get_api_logger()
@@ -84,6 +84,9 @@ async def trigger_forgetting_cycle(
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(int(config_id), db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
@@ -129,7 +132,7 @@ async def trigger_forgetting_cycle(
@router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config(
config_id: UUID,
config_id: UUID|int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -158,6 +161,7 @@ async def read_forgetting_config(
)
try:
config_id=resolve_config_id(config_id, db)
# 调用服务层读取配置
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
@@ -195,6 +199,8 @@ async def update_forgetting_config(
ApiResponse: 包含更新结果的响应
"""
workspace_id = current_user.current_workspace_id
payload.config_id=resolve_config_id(int(payload.config_id), db)
# 检查用户是否已选择工作空间
if workspace_id is None:
@@ -255,12 +261,10 @@ async def get_forgetting_stats(
ApiResponse: 包含统计信息的响应
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 如果提供了 end_user_id通过它获取 config_id
config_id = None
if end_user_id:
@@ -269,6 +273,7 @@ async def get_forgetting_stats(
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
@@ -325,7 +330,7 @@ async def get_forgetting_curve(
ApiResponse: 包含遗忘曲线数据的响应
"""
workspace_id = current_user.current_workspace_id
request.config_id = resolve_config_id(int(request.config_id), db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")

View File

@@ -25,6 +25,8 @@ from fastapi import APIRouter, Depends, HTTPException, status,Header
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
load_dotenv()
api_logger = get_api_logger()
@@ -157,17 +159,19 @@ async def start_workspace_reflection(
@router.get("/reflection/configs")
async def start_reflection_configs(
config_id: uuid.UUID,
config_id: uuid.UUID|int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""通过config_id查询memory_config表中的反思配置信息"""
try:
config_id=resolve_config_id(config_id,db)
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
memory_config_id = resolve_config_id(result.config_id, db)
# 构建返回数据
reflection_config = {
"config_id": result.config_id,
"config_id": memory_config_id,
"reflection_enabled": result.enable_self_reflexion,
"reflection_period_in_hours": result.iteration_period,
"reflexion_range": result.reflexion_range,
@@ -192,7 +196,7 @@ async def start_reflection_configs(
@router.get("/reflection/run")
async def reflection_run(
config_id: UUID,
config_id: UUID|int,
language_type: str = Header(default="zh", alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
@@ -200,7 +204,7 @@ async def reflection_run(
"""Activate the reflection function for all matching applications in the workspace"""
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
config_id = resolve_config_id(config_id, db)
# 使用MemoryConfigRepository查询反思配置
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
if not result:

View File

@@ -24,6 +24,8 @@ from app.schemas.memory_storage_schema import (
from sqlalchemy import desc, select
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
# 获取数据库专用日志器
db_logger = get_db_logger()
# 获取配置专用日志器
@@ -410,7 +412,7 @@ class MemoryConfigRepository:
raise
@staticmethod
def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]:
def get_extracted_config(db: Session, config_id: UUID |int) -> Optional[Dict]:
"""获取萃取配置,通过主键查询某条配置
Args:
@@ -420,8 +422,8 @@ class MemoryConfigRepository:
Returns:
Optional[Dict]: 萃取配置字典不存在则返回None
"""
config_id=resolve_config_id(config_id,db)
db_logger.debug(f"查询萃取配置: config_id={config_id}")
try:
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config:

View File

@@ -147,7 +147,7 @@ class ReflexionResultSchema(BaseModel):
# Composite key identifying a config row
class ConfigKey(BaseModel): # 配置参数键模型
model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: uuid.UUID = Field("config_id", description="配置唯一标识UUID")
config_id: Union[uuid.UUID, int] = Field(..., description="配置唯一标识UUID或int)")
user_id: str = Field("user_id", description="用户标识(字符串)")
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
@@ -423,8 +423,8 @@ class ForgettingConfigResponse(BaseModel):
class ForgettingConfigUpdateRequest(BaseModel):
"""遗忘引擎配置更新请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: uuid.UUID = Field(..., description="配置ID")
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识UUID或int)")
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
@@ -499,7 +499,7 @@ class ForgettingCurveRequest(BaseModel):
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数0-1")
days: int = Field(60, ge=1, le=365, description="模拟天数默认60天")
config_id: Optional[uuid.UUID] = Field(None, description="配置ID可选如果为None则使用默认配置")
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
class ForgettingCurveResponse(BaseModel):

View File

@@ -334,7 +334,9 @@ class MemoryAgentService:
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
print(100*'-')
print(langchain_messages)
print(100*'-')
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,

View File

@@ -0,0 +1,35 @@
"""
Configuration utility functions
Shared utilities for configuration handling to avoid circular imports.
"""
from uuid import UUID
from sqlalchemy.orm import Session
def resolve_config_id(config_id: UUID | int, db: Session) -> UUID:
"""
解析 config_id如果是整数则通过 config_id_old 查找对应的 UUID
Args:
config_id: 配置IDUUID 或整数)
db: 数据库会话
Returns:
UUID: 解析后的配置ID
Raises:
ValueError: 当找不到对应的配置时
"""
if isinstance(config_id, int):
from app.models.memory_config_model import MemoryConfig
memory_config = db.query(MemoryConfig).filter(
MemoryConfig.config_id_old == config_id
).first()
if not memory_config:
raise ValueError(f"未找到 config_id_old={config_id} 对应的配置")
return memory_config.config_id
return config_id