refactor(memory): restructure memory agent and config management

- Reorganize imports and remove unused dependencies across memory agent controllers
- Extract config validation logic into dedicated validators module
- Create new memory_config_model and memory_config_schema for configuration management
- Implement memory_config_service for centralized config handling
- Add embedder_utils module for embedding model utilities
- Refactor memory agent service to use new config validation framework
- Clean up configuration files (remove config.json, testdata.json, dbrun.json)
- Remove deprecated hybrid_chatbot.py and config overrides
- Update logging configuration and error handling across memory modules
- Consolidate LLM and embedding model validation into validators
- Improve code organization and reduce duplication in memory storage services
- Enhance type classification and verification tools with better error handling
This commit is contained in:
Ke Sun
2025-12-21 20:32:41 +08:00
parent 7386ea32f1
commit 1e3ba39150
53 changed files with 3122 additions and 3407 deletions

View File

@@ -1,36 +1,28 @@
import json
import time
from typing import Optional, List
from fastapi import APIRouter, Depends, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.db import get_db
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.rag.llm.cv_model import QWenCV
from app.models import ModelApiKey, Knowledge
from app.services.memory_agent_service import MemoryAgentService
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
from typing import List, Optional
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.services import task_service, workspace_service
from app.core.logging_config import get_api_logger
from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey
from app.models.user_model import User
from app.repositories import knowledge_repository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.dependencies import get_current_user
from app.models.user_model import User
from fastapi import APIRouter, Depends, File, UploadFile, Form
from app.repositories import knowledge_repository
from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
import os
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
# 加载.env文件
load_dotenv()
# Get API logger
api_logger = get_api_logger()
# Initialize service
memory_agent_service = MemoryAgentService()
router = APIRouter(
@@ -39,95 +31,6 @@ router = APIRouter(
)
def validate_config_id(config_id: int, db: Session) -> int:
"""
Validate and ensure config_id is available, valid, and exists in database.
Args:
config_id: Configuration ID to validate
db: Database session for checking existence
Returns:
int: Validated config_id
Raises:
ValueError: If config_id is None, invalid, or doesn't exist in database
"""
if config_id is None:
api_logger.info("config_id is required but was not provided")
config_id = os.getenv('config_id')
if config_id is None:
raise ValueError("config_id is required but was not provided")
# Check if config exists in database
try:
from app.models.data_config_model import DataConfig
from app.models.models_model import ModelConfig
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if config is None:
error_msg = f"Configuration with config_id={config_id} does not exist in database"
api_logger.error(error_msg)
raise ValueError(error_msg)
# Validate llm_id exists and is usable
if config.llm_id:
try:
llm_config = db.query(ModelConfig).filter(ModelConfig.id == config.llm_id).first()
if llm_config is None:
error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) does not exist"
api_logger.error(error_msg)
raise ValueError(error_msg)
if not llm_config.is_active:
error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) is not active"
api_logger.error(error_msg)
raise ValueError(error_msg)
api_logger.debug(f"LLM validation successful: llm_id={config.llm_id}, name={llm_config.name}")
except ValueError:
raise
except Exception as e:
error_msg = f"Error validating LLM model: {str(e)}"
api_logger.error(error_msg, exc_info=True)
raise ValueError(error_msg)
else:
api_logger.error(f"Config {config_id} has no llm_id set")
raise ValueError(f"Config {config_id} has no llm_id set")
# Validate embedding_id exists and is usable
if config.embedding_id:
try:
embedding_config = db.query(ModelConfig).filter(ModelConfig.id == config.embedding_id).first()
if embedding_config is None:
error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) does not exist"
api_logger.error(error_msg)
raise ValueError(error_msg)
if not embedding_config.is_active:
error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) is not active"
api_logger.error(error_msg)
raise ValueError(error_msg)
api_logger.debug(f"Embedding validation successful: embedding_id={config.embedding_id}, name={embedding_config.name}")
except ValueError:
raise
except Exception as e:
error_msg = f"Error validating embedding model: {str(e)}"
api_logger.error(error_msg, exc_info=True)
raise ValueError(error_msg)
else:
api_logger.error(f"Config {config_id} has no embedding_id set")
raise ValueError(f"Config {config_id} has no embedding_id set")
api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
return config_id
except ValueError:
# Re-raise ValueError from above
raise
except Exception as e:
error_msg = f"Database error while validating config_id={config_id}: {str(e)}"
api_logger.error(error_msg, exc_info=True)
raise ValueError(error_msg)
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
current_user: User = Depends(get_current_user)
@@ -225,12 +128,7 @@ async def write_server(
Returns:
Response with write operation status
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
@@ -270,8 +168,14 @@ async def write_server(
user_rag_memory_id
)
return success(data=result, msg="写入成功")
except Exception as e:
api_logger.error(f"Write operation error: {str(e)}")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@@ -292,12 +196,7 @@ async def write_server_async(
Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
@@ -352,12 +251,7 @@ async def read_server(
Returns:
Response with query answer
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
@@ -390,8 +284,14 @@ async def read_server(
user_rag_memory_id
)
return success(data=result, msg="回复对话消息成功")
except Exception as e:
api_logger.error(f"Read operation error: {str(e)}")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
api_logger.error(f"Read operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", detailed_error)
api_logger.error(f"Read operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
@@ -456,12 +356,7 @@ async def read_server_async(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")

View File

@@ -1,45 +1,45 @@
from typing import Optional, Union
import os
import uuid
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, UploadFile
from fastapi.responses import StreamingResponse
from typing import Optional
from app.db import get_db
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.memory.utils.self_reflexion_utils import self_reflexion
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigPilotRun,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_storage_service import (
MemoryStorageService,
DataConfigService,
kb_type_distribution,
search_dialogue,
search_chunk,
search_statement,
search_entity,
search_all,
search_detials,
search_edges,
search_entity_graph,
MemoryStorageService,
analytics_hot_memory_tags,
analytics_memory_insight_report,
analytics_recent_activity_stats,
analytics_user_summary,
kb_type_distribution,
search_all,
search_chunk,
search_detials,
search_dialogue,
search_edges,
search_entity,
search_entity_graph,
search_statement,
)
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import (
ConfigParamsCreate,
ConfigParamsDelete,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
ConfigPilotRun,
)
from app.core.memory.utils.config.definitions import reload_configuration_from_database
from app.dependencies import get_current_user
from app.models.user_model import User
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
# Get API logger
api_logger = get_api_logger()
@@ -329,8 +329,10 @@ async def pilot_run(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StreamingResponse:
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
api_logger.info(
f"Pilot run requested: config_id={payload.config_id}, "
f"dialogue_text_length={len(payload.dialogue_text)}"
)
svc = DataConfigService(db)
return StreamingResponse(
svc.pilot_run_stream(payload),
@@ -338,8 +340,8 @@ async def pilot_run(
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
"X-Accel-Buffering": "no",
},
)
"""
@@ -528,8 +530,8 @@ async def get_user_summary_api(
except Exception as e:
api_logger.error(f"User summary failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e))
from app.core.memory.utils.self_reflexion_utils import self_reflexion
@router.get("/self_reflexion")
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
"""