refactor(memory): restructure memory agent and config management

- Reorganize imports and remove unused dependencies across memory agent controllers
- Extract config validation logic into dedicated validators module
- Create new memory_config_model and memory_config_schema for configuration management
- Implement memory_config_service for centralized config handling
- Add embedder_utils module for embedding model utilities
- Refactor memory agent service to use new config validation framework
- Clean up configuration files (remove config.json, testdata.json, dbrun.json)
- Remove deprecated hybrid_chatbot.py and config overrides
- Update logging configuration and error handling across memory modules
- Consolidate LLM and embedding model validation into validators
- Improve code organization and reduce duplication in memory storage services
- Enhance type classification and verification tools with better error handling
This commit is contained in:
Ke Sun
2025-12-21 20:32:41 +08:00
parent 7386ea32f1
commit 1e3ba39150
53 changed files with 3122 additions and 3407 deletions

View File

@@ -1,36 +1,28 @@
import json
import time
from typing import Optional, List
from fastapi import APIRouter, Depends, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.db import get_db
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.rag.llm.cv_model import QWenCV
from app.models import ModelApiKey, Knowledge
from app.services.memory_agent_service import MemoryAgentService
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
from typing import List, Optional
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.services import task_service, workspace_service
from app.core.logging_config import get_api_logger
from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey
from app.models.user_model import User
from app.repositories import knowledge_repository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.dependencies import get_current_user
from app.models.user_model import User
from fastapi import APIRouter, Depends, File, UploadFile, Form
from app.repositories import knowledge_repository
from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
import os
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
# 加载.env文件
load_dotenv()
# Get API logger
api_logger = get_api_logger()
# Initialize service
memory_agent_service = MemoryAgentService()
router = APIRouter(
@@ -39,95 +31,6 @@ router = APIRouter(
)
def _validate_model_reference(db, model_cls, model_id, config_id, label, error_label, id_field):
    """Ensure the model row referenced by a configuration exists and is active.

    Args:
        db: Database session used for the lookup.
        model_cls: ORM class to query; filtered on its ``id`` attribute.
        model_id: Identifier of the referenced model.
        config_id: Owning configuration ID, used only in messages.
        label: Name used in the "does not exist" / "is not active" /
            "validation successful" messages (e.g. ``"LLM"``).
        error_label: Name used in the generic "Error validating ..." message.
        id_field: Name of the config attribute holding ``model_id``; used
            only in the debug log line.

    Raises:
        ValueError: If the model is missing, inactive, or the lookup fails.
    """
    try:
        model = db.query(model_cls).filter(model_cls.id == model_id).first()
        if model is None:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) does not exist"
            api_logger.error(error_msg)
            raise ValueError(error_msg)

        if not model.is_active:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) is not active"
            api_logger.error(error_msg)
            raise ValueError(error_msg)

        api_logger.debug(f"{label} validation successful: {id_field}={model_id}, name={model.name}")
    except ValueError:
        # Validation failures raised above already carry the final message.
        raise
    except Exception as e:
        error_msg = f"Error validating {error_label} model: {str(e)}"
        api_logger.error(error_msg, exc_info=True)
        # Chain the cause so the original traceback is not lost.
        raise ValueError(error_msg) from e


def validate_config_id(config_id: int, db: Session) -> int:
    """
    Validate and ensure config_id is available, valid, and exists in database.

    When ``config_id`` is None, the ``config_id`` environment variable is used
    as a fallback. The configuration row must exist, and the LLM and embedding
    models it references must both exist and be active.

    Args:
        config_id: Configuration ID to validate
        db: Database session for checking existence

    Returns:
        int: Validated config_id

    Raises:
        ValueError: If config_id is None, invalid, or doesn't exist in database
    """
    if config_id is None:
        api_logger.info("config_id is required but was not provided")
        env_value = os.getenv('config_id')
        if env_value is None:
            raise ValueError("config_id is required but was not provided")
        # os.getenv returns a string; convert it so the function honours its
        # declared int return type and compares against the column as an int.
        try:
            config_id = int(env_value)
        except ValueError as e:
            error_msg = f"config_id from environment is not a valid integer: {env_value!r}"
            api_logger.error(error_msg)
            raise ValueError(error_msg) from e

    try:
        # Imported inside the try so an import failure is also reported to
        # the caller as a ValueError.
        from app.models.data_config_model import DataConfig
        from app.models.models_model import ModelConfig

        # Check if config exists in database
        config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
        if config is None:
            error_msg = f"Configuration with config_id={config_id} does not exist in database"
            api_logger.error(error_msg)
            raise ValueError(error_msg)

        # Validate llm_id exists and is usable
        if not config.llm_id:
            api_logger.error(f"Config {config_id} has no llm_id set")
            raise ValueError(f"Config {config_id} has no llm_id set")
        _validate_model_reference(
            db, ModelConfig, config.llm_id, config_id,
            label="LLM", error_label="LLM", id_field="llm_id",
        )

        # Validate embedding_id exists and is usable
        if not config.embedding_id:
            api_logger.error(f"Config {config_id} has no embedding_id set")
            raise ValueError(f"Config {config_id} has no embedding_id set")
        _validate_model_reference(
            db, ModelConfig, config.embedding_id, config_id,
            label="Embedding", error_label="embedding", id_field="embedding_id",
        )

        api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
        return config_id
    except ValueError:
        # Re-raise ValueError from above
        raise
    except Exception as e:
        error_msg = f"Database error while validating config_id={config_id}: {str(e)}"
        api_logger.error(error_msg, exc_info=True)
        raise ValueError(error_msg) from e
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
current_user: User = Depends(get_current_user)
@@ -225,12 +128,7 @@ async def write_server(
Returns:
Response with write operation status
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
@@ -270,8 +168,14 @@ async def write_server(
user_rag_memory_id
)
return success(data=result, msg="写入成功")
except Exception as e:
api_logger.error(f"Write operation error: {str(e)}")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@@ -292,12 +196,7 @@ async def write_server_async(
Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
@@ -352,12 +251,7 @@ async def read_server(
Returns:
Response with query answer
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
@@ -390,8 +284,14 @@ async def read_server(
user_rag_memory_id
)
return success(data=result, msg="回复对话消息成功")
except Exception as e:
api_logger.error(f"Read operation error: {str(e)}")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
api_logger.error(f"Read operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", detailed_error)
api_logger.error(f"Read operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
@@ -456,12 +356,7 @@ async def read_server_async(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")