Merge pull request #225 from SuanmoSuanyangTechnology/fix/memory_bug_fix
Fix/memory bug fix
This commit is contained in:
@@ -21,6 +21,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_config_service import EmotionConfigService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -46,7 +47,7 @@ class EmotionConfigUpdate(BaseModel):
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
def get_emotion_config(
|
||||
config_id: UUID = Query(..., description="配置ID"),
|
||||
config_id: UUID|int = Query(..., description="配置ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -79,7 +80,7 @@ def get_emotion_config(
|
||||
f"用户 {current_user.username} 请求获取情绪配置",
|
||||
extra={"config_id": config_id}
|
||||
)
|
||||
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 初始化服务
|
||||
config_service = EmotionConfigService(db)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -84,6 +84,9 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(int(config_id), db)
|
||||
|
||||
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
@@ -129,7 +132,7 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
async def read_forgetting_config(
|
||||
config_id: UUID,
|
||||
config_id: UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -158,6 +161,7 @@ async def read_forgetting_config(
|
||||
)
|
||||
|
||||
try:
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 调用服务层读取配置
|
||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||
|
||||
@@ -195,6 +199,8 @@ async def update_forgetting_config(
|
||||
ApiResponse: 包含更新结果的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id=resolve_config_id(int(payload.config_id), db)
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
@@ -255,12 +261,10 @@ async def get_forgetting_stats(
|
||||
ApiResponse: 包含统计信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 如果提供了 end_user_id,通过它获取 config_id
|
||||
config_id = None
|
||||
if end_user_id:
|
||||
@@ -269,6 +273,7 @@ async def get_forgetting_stats(
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
@@ -325,7 +330,7 @@ async def get_forgetting_curve(
|
||||
ApiResponse: 包含遗忘曲线数据的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
request.config_id = resolve_config_id(int(request.config_id), db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||
|
||||
@@ -25,6 +25,8 @@ from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -157,17 +159,19 @@ async def start_workspace_reflection(
|
||||
|
||||
@router.get("/reflection/configs")
|
||||
async def start_reflection_configs(
|
||||
config_id: uuid.UUID,
|
||||
config_id: uuid.UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
# 构建返回数据
|
||||
reflection_config = {
|
||||
"config_id": result.config_id,
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
"reflection_period_in_hours": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
@@ -192,7 +196,7 @@ async def start_reflection_configs(
|
||||
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: UUID,
|
||||
config_id: UUID|int,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -200,7 +204,7 @@ async def reflection_run(
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
|
||||
@@ -24,6 +24,8 @@ from app.schemas.memory_storage_schema import (
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
# 获取配置专用日志器
|
||||
@@ -410,7 +412,7 @@ class MemoryConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]:
|
||||
def get_extracted_config(db: Session, config_id: UUID |int) -> Optional[Dict]:
|
||||
"""获取萃取配置,通过主键查询某条配置
|
||||
|
||||
Args:
|
||||
@@ -420,8 +422,8 @@ class MemoryConfigRepository:
|
||||
Returns:
|
||||
Optional[Dict]: 萃取配置字典,不存在则返回None
|
||||
"""
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
db_logger.debug(f"查询萃取配置: config_id={config_id}")
|
||||
|
||||
try:
|
||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
|
||||
@@ -147,7 +147,7 @@ class ReflexionResultSchema(BaseModel):
|
||||
# Composite key identifying a config row
|
||||
class ConfigKey(BaseModel): # 配置参数键模型
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
config_id: uuid.UUID = Field("config_id", description="配置唯一标识(UUID)")
|
||||
config_id: Union[uuid.UUID, int] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
user_id: str = Field("user_id", description="用户标识(字符串)")
|
||||
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
|
||||
|
||||
@@ -423,8 +423,8 @@ class ForgettingConfigResponse(BaseModel):
|
||||
class ForgettingConfigUpdateRequest(BaseModel):
|
||||
"""遗忘引擎配置更新请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
config_id: uuid.UUID = Field(..., description="配置ID")
|
||||
|
||||
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
||||
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
||||
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
||||
@@ -499,7 +499,7 @@ class ForgettingCurveRequest(BaseModel):
|
||||
|
||||
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)")
|
||||
days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)")
|
||||
config_id: Optional[uuid.UUID] = Field(None, description="配置ID(可选,如果为None则使用默认配置)")
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
|
||||
|
||||
class ForgettingCurveResponse(BaseModel):
|
||||
|
||||
@@ -334,7 +334,9 @@ class MemoryAgentService:
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
|
||||
print(100*'-')
|
||||
print(langchain_messages)
|
||||
print(100*'-')
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
|
||||
35
api/app/utils/config_utils.py
Normal file
35
api/app/utils/config_utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Configuration utility functions
|
||||
|
||||
Shared utilities for configuration handling to avoid circular imports.
|
||||
"""
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def resolve_config_id(config_id: UUID | int, db: Session) -> UUID:
|
||||
"""
|
||||
解析 config_id,如果是整数则通过 config_id_old 查找对应的 UUID
|
||||
|
||||
Args:
|
||||
config_id: 配置ID(UUID 或整数)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
UUID: 解析后的配置ID
|
||||
|
||||
Raises:
|
||||
ValueError: 当找不到对应的配置时
|
||||
"""
|
||||
if isinstance(config_id, int):
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
memory_config = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.config_id_old == config_id
|
||||
).first()
|
||||
|
||||
if not memory_config:
|
||||
raise ValueError(f"未找到 config_id_old={config_id} 对应的配置")
|
||||
|
||||
return memory_config.config_id
|
||||
|
||||
return config_id
|
||||
Reference in New Issue
Block a user