refactor(memory): restructure memory agent and config management
- Reorganize imports and remove unused dependencies across memory agent controllers
- Extract config validation logic into a dedicated validators module
- Create new memory_config_model and memory_config_schema for configuration management
- Implement memory_config_service for centralized config handling
- Add embedder_utils module for embedding model utilities
- Refactor memory agent service to use the new config validation framework
- Clean up configuration files (remove config.json, testdata.json, dbrun.json)
- Remove deprecated hybrid_chatbot.py and config overrides
- Update logging configuration and error handling across memory modules
- Consolidate LLM and embedding model validation into validators
- Improve code organization and reduce duplication in memory storage services
- Enhance type classification and verification tools with better error handling
This commit is contained in:
@@ -1,36 +1,28 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
from app.db import get_db
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.models import ModelApiKey, Knowledge
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
||||
from typing import List, Optional
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services import task_service, workspace_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||
from app.models import ModelApiKey
|
||||
from app.models.user_model import User
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from fastapi import APIRouter, Depends, File, UploadFile, Form
|
||||
from app.repositories import knowledge_repository
|
||||
from app.services import task_service, workspace_service
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
# Load the .env file
|
||||
load_dotenv()
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Initialize service
|
||||
memory_agent_service = MemoryAgentService()
|
||||
|
||||
router = APIRouter(
|
||||
@@ -39,95 +31,6 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def validate_config_id(config_id: int, db: Session) -> int:
    """
    Validate that a configuration exists and that its referenced models are usable.

    When ``config_id`` is None, the ``config_id`` environment variable is used
    as a fallback and converted to ``int`` so the returned value matches the
    annotated type. The configuration row is then looked up, and both its
    ``llm_id`` and ``embedding_id`` must point at existing, active models.

    Args:
        config_id: Configuration ID to validate; None triggers the
            environment-variable fallback.
        db: Database session used for the existence checks.

    Returns:
        int: The validated config_id.

    Raises:
        ValueError: If config_id is missing or not an integer, the
            configuration does not exist, a referenced model is missing,
            unset or inactive, or any unexpected error occurs during lookup
            (the original error is attached as ``__cause__``).
    """
    if config_id is None:
        api_logger.info("config_id was not provided, falling back to the 'config_id' environment variable")
        env_value = os.getenv('config_id')
        if env_value is None:
            raise ValueError("config_id is required but was not provided")
        # os.getenv returns a string; normalise it so callers always receive an int.
        try:
            config_id = int(env_value)
        except ValueError as e:
            raise ValueError(
                f"config_id from environment is not a valid integer: {env_value!r}"
            ) from e

    try:
        # Imported lazily inside the try so that an import failure is reported
        # through the same ValueError path as any other lookup failure.
        from app.models.data_config_model import DataConfig
        from app.models.models_model import ModelConfig

        config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
        if config is None:
            error_msg = f"Configuration with config_id={config_id} does not exist in database"
            api_logger.error(error_msg)
            raise ValueError(error_msg)

        _validate_referenced_model(
            db, ModelConfig, config.llm_id, config_id,
            label="LLM", noun="LLM", id_field="llm_id",
        )
        _validate_referenced_model(
            db, ModelConfig, config.embedding_id, config_id,
            label="Embedding", noun="embedding", id_field="embedding_id",
        )

        api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
        return config_id
    except ValueError:
        # Validation failures already carry a precise message; pass them through.
        raise
    except Exception as e:
        error_msg = f"Database error while validating config_id={config_id}: {str(e)}"
        api_logger.error(error_msg, exc_info=True)
        raise ValueError(error_msg) from e


def _validate_referenced_model(db: Session, model_cls, model_id, config_id, *, label: str, noun: str, id_field: str) -> None:
    """Raise ValueError unless ``model_id`` is set and refers to an existing, active model row."""
    if not model_id:
        error_msg = f"Config {config_id} has no {id_field} set"
        api_logger.error(error_msg)
        raise ValueError(error_msg)

    try:
        model = db.query(model_cls).filter(model_cls.id == model_id).first()
        if model is None:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) does not exist"
            api_logger.error(error_msg)
            raise ValueError(error_msg)
        if not model.is_active:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) is not active"
            api_logger.error(error_msg)
            raise ValueError(error_msg)
        api_logger.debug(f"{label} validation successful: {id_field}={model_id}, name={model.name}")
    except ValueError:
        # Already a validation error with a specific message; do not re-wrap.
        raise
    except Exception as e:
        error_msg = f"Error validating {noun} model: {str(e)}"
        api_logger.error(error_msg, exc_info=True)
        raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
@router.get("/health/status", response_model=ApiResponse)
|
||||
async def get_health_status(
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -225,12 +128,7 @@ async def write_server(
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
@@ -270,8 +168,14 @@ async def write_server(
|
||||
user_rag_memory_id
|
||||
)
|
||||
return success(data=result, msg="写入成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Write operation error: {str(e)}")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@@ -292,12 +196,7 @@ async def write_server_async(
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
@@ -352,12 +251,7 @@ async def read_server(
|
||||
Returns:
|
||||
Response with query answer
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
@@ -390,8 +284,14 @@ async def read_server(
|
||||
user_rag_memory_id
|
||||
)
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read operation error: {str(e)}")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Read operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", detailed_error)
|
||||
api_logger.error(f"Read operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
|
||||
|
||||
|
||||
@@ -456,12 +356,7 @@ async def read_server_async(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user