Merge #44 into develop from refactor/memory-config-management
Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management * refactor/memory-config-management: (7 commits) refactor(memory): restructure memory agent and config management Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management refactor(memory): restructure memory system and improve configuration management Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management refactor(memory): reorganize imports and move MemoryClientFactory to utils feat(memory): make config_id optional and improve configuration validation Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management Signed-off-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com> Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/44
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
from celery import Celery
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
@@ -13,7 +13,6 @@ celery_app = Celery(
|
||||
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||
)
|
||||
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
|
||||
|
||||
# 配置使用本地队列,避免与远程 worker 冲突
|
||||
celery_app.conf.task_default_queue = 'localhost_test_wyl'
|
||||
@@ -22,6 +21,7 @@ celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
|
||||
|
||||
# macOS 兼容性配置
|
||||
import platform
|
||||
|
||||
if platform.system() == 'Darwin': # macOS
|
||||
# 设置环境变量解决 fork 问题
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
@@ -10,22 +10,21 @@ Routes:
|
||||
POST /emotion/suggestions - 获取个性化情绪建议
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.emotion_schema import (
|
||||
EmotionHealthRequest,
|
||||
EmotionSuggestionsRequest,
|
||||
EmotionTagsRequest,
|
||||
EmotionWordcloudRequest,
|
||||
EmotionHealthRequest,
|
||||
EmotionSuggestionsRequest
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -211,13 +210,28 @@ async def get_emotion_suggestions(
|
||||
"""
|
||||
try:
|
||||
# 验证 config_id(如果提供)
|
||||
# 获取终端用户关联的配置
|
||||
config_id = request.config_id
|
||||
if config_id is not None:
|
||||
from app.controllers.memory_agent_controller import validate_config_id
|
||||
if config_id is None:
|
||||
# 如果没有提供 config_id,尝试获取用户关联的配置
|
||||
try:
|
||||
config_id = validate_config_id(config_id, db)
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(request.group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
return fail(BizCode.INVALID_PARAMETER, "无法获取用户关联的配置", str(e))
|
||||
else:
|
||||
# 如果提供了 config_id,验证其有效性
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
config = config_service.get_config_by_id(config_id)
|
||||
if not config:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", f"配置 {config_id} 不存在")
|
||||
except Exception as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID验证失败", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议",
|
||||
@@ -230,7 +244,7 @@ async def get_emotion_suggestions(
|
||||
# 调用服务层
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
config_id=config_id
|
||||
db=db
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
|
||||
@@ -1,36 +1,28 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
from app.db import get_db
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.models import ModelApiKey, Knowledge
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
||||
from typing import List, Optional
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services import task_service, workspace_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||
from app.models import ModelApiKey
|
||||
from app.models.user_model import User
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from fastapi import APIRouter, Depends, File, UploadFile, Form
|
||||
from app.repositories import knowledge_repository
|
||||
from app.services import task_service, workspace_service
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Initialize service
|
||||
memory_agent_service = MemoryAgentService()
|
||||
|
||||
router = APIRouter(
|
||||
@@ -39,95 +31,6 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def validate_config_id(config_id: int, db: Session) -> int:
|
||||
"""
|
||||
Validate and ensure config_id is available, valid, and exists in database.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID to validate
|
||||
db: Database session for checking existence
|
||||
|
||||
Returns:
|
||||
int: Validated config_id
|
||||
|
||||
Raises:
|
||||
ValueError: If config_id is None, invalid, or doesn't exist in database
|
||||
"""
|
||||
if config_id is None:
|
||||
api_logger.info("config_id is required but was not provided")
|
||||
config_id = os.getenv('config_id')
|
||||
if config_id is None:
|
||||
raise ValueError("config_id is required but was not provided")
|
||||
|
||||
|
||||
# Check if config exists in database
|
||||
try:
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.models_model import ModelConfig
|
||||
|
||||
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if config is None:
|
||||
error_msg = f"Configuration with config_id={config_id} does not exist in database"
|
||||
api_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Validate llm_id exists and is usable
|
||||
if config.llm_id:
|
||||
try:
|
||||
llm_config = db.query(ModelConfig).filter(ModelConfig.id == config.llm_id).first()
|
||||
if llm_config is None:
|
||||
error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) does not exist"
|
||||
api_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
if not llm_config.is_active:
|
||||
error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) is not active"
|
||||
api_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
api_logger.debug(f"LLM validation successful: llm_id={config.llm_id}, name={llm_config.name}")
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Error validating LLM model: {str(e)}"
|
||||
api_logger.error(error_msg, exc_info=True)
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
api_logger.error(f"Config {config_id} has no llm_id set")
|
||||
raise ValueError(f"Config {config_id} has no llm_id set")
|
||||
|
||||
# Validate embedding_id exists and is usable
|
||||
if config.embedding_id:
|
||||
try:
|
||||
embedding_config = db.query(ModelConfig).filter(ModelConfig.id == config.embedding_id).first()
|
||||
if embedding_config is None:
|
||||
error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) does not exist"
|
||||
api_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
if not embedding_config.is_active:
|
||||
error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) is not active"
|
||||
api_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
api_logger.debug(f"Embedding validation successful: embedding_id={config.embedding_id}, name={embedding_config.name}")
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Error validating embedding model: {str(e)}"
|
||||
api_logger.error(error_msg, exc_info=True)
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
api_logger.error(f"Config {config_id} has no embedding_id set")
|
||||
raise ValueError(f"Config {config_id} has no embedding_id set")
|
||||
|
||||
api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
|
||||
return config_id
|
||||
except ValueError:
|
||||
# Re-raise ValueError from above
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Database error while validating config_id={config_id}: {str(e)}"
|
||||
api_logger.error(error_msg, exc_info=True)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
@router.get("/health/status", response_model=ApiResponse)
|
||||
async def get_health_status(
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -225,12 +128,7 @@ async def write_server(
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
@@ -265,13 +163,20 @@ async def write_server(
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.group_id,
|
||||
user_input.message,
|
||||
config_id,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
return success(data=result, msg="写入成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Write operation error: {str(e)}")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@@ -292,12 +197,7 @@ async def write_server_async(
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
@@ -352,12 +252,7 @@ async def read_server(
|
||||
Returns:
|
||||
Response with query answer
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
@@ -386,12 +281,19 @@ async def read_server(
|
||||
user_input.history,
|
||||
user_input.search_switch,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read operation error: {str(e)}")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Read operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", detailed_error)
|
||||
api_logger.error(f"Read operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
|
||||
|
||||
|
||||
@@ -456,12 +358,7 @@ async def read_server_async(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
@@ -653,6 +550,7 @@ async def get_write_task_result(
|
||||
@router.post("/status_type", response_model=ApiResponse)
|
||||
async def status_type(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
@@ -666,7 +564,11 @@ async def status_type(
|
||||
"""
|
||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
||||
try:
|
||||
result = await memory_agent_service.classify_message_type(user_input.message)
|
||||
result = await memory_agent_service.classify_message_type(
|
||||
user_input.message,
|
||||
user_input.config_id,
|
||||
db
|
||||
)
|
||||
return success(data=result)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Message type classification failed: {str(e)}")
|
||||
@@ -741,6 +643,7 @@ async def get_hot_memory_tags_by_user_api(
|
||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||
async def get_user_profile_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
@@ -764,7 +667,8 @@ async def get_user_profile_api(
|
||||
try:
|
||||
result = await memory_agent_service.get_user_profile(
|
||||
end_user_id=end_user_id,
|
||||
current_user_id=str(current_user.id)
|
||||
current_user_id=str(current_user.id),
|
||||
db=db
|
||||
)
|
||||
return success(data=result, msg="获取用户详情成功")
|
||||
except Exception as e:
|
||||
@@ -799,4 +703,41 @@ async def get_user_profile_api(
|
||||
# )
|
||||
# except Exception as e:
|
||||
# api_logger.error(f"API docs retrieval failed: {str(e)}")
|
||||
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
|
||||
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
|
||||
|
||||
|
||||
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
||||
async def get_end_user_connected_config(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
|
||||
通过以下流程获取配置:
|
||||
1. 根据 end_user_id 获取用户的 app_id
|
||||
2. 获取该应用的最新发布版本
|
||||
3. 从发布版本的 config 字段中提取 memory_config_id
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
包含 memory_config_id 和相关信息的响应
|
||||
"""
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config as get_config,
|
||||
)
|
||||
|
||||
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
try:
|
||||
result = get_config(end_user_id, db)
|
||||
return success(data=result, msg="获取终端用户关联配置成功")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"End user config not found: {str(e)}")
|
||||
return fail(BizCode.NOT_FOUND, str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||
@@ -1,22 +1,27 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
ReflectionConfig,
|
||||
ReflectionEngine,
|
||||
)
|
||||
from app.core.response_utils import success
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine
|
||||
from app.dependencies import get_current_user
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.services.memory_reflection_service import (
|
||||
MemoryReflectionService,
|
||||
WorkspaceAppService,
|
||||
)
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -1,50 +1,50 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||
from app.services.memory_storage_service import (
|
||||
MemoryStorageService,
|
||||
DataConfigService,
|
||||
kb_type_distribution,
|
||||
search_dialogue,
|
||||
search_chunk,
|
||||
search_statement,
|
||||
search_entity,
|
||||
search_all,
|
||||
search_detials,
|
||||
search_edges,
|
||||
search_entity_graph,
|
||||
analytics_hot_memory_tags,
|
||||
analytics_recent_activity_stats,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
ConfigPilotRun,
|
||||
GenerateCacheRequest,
|
||||
)
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.user_model import User
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
)
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
GenerateCacheRequest,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_storage_service import (
|
||||
DataConfigService,
|
||||
MemoryStorageService,
|
||||
analytics_hot_memory_tags,
|
||||
analytics_recent_activity_stats,
|
||||
kb_type_distribution,
|
||||
search_all,
|
||||
search_chunk,
|
||||
search_detials,
|
||||
search_dialogue,
|
||||
search_edges,
|
||||
search_entity,
|
||||
search_entity_graph,
|
||||
search_statement,
|
||||
)
|
||||
from app.services.user_memory_service import analytics_user_summary
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -335,8 +335,10 @@ async def pilot_run(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
|
||||
|
||||
api_logger.info(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||
)
|
||||
svc = DataConfigService(db)
|
||||
return StreamingResponse(
|
||||
svc.pilot_run_stream(payload),
|
||||
@@ -344,8 +346,8 @@ async def pilot_run(
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -508,6 +510,20 @@ async def get_recent_activity_stats_api(
|
||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"User summary requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await analytics_user_summary(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"User summary failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e))
|
||||
|
||||
|
||||
@router.get("/self_reflexion")
|
||||
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
|
||||
"""
|
||||
|
||||
@@ -9,18 +9,19 @@ LangChain Agent 封装
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain.agents import create_agent
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -198,10 +199,24 @@ class LangChainAgent:
|
||||
"""
|
||||
message_chat= message
|
||||
start_time = time.time()
|
||||
if config_id == None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:
|
||||
actual_config_id = config_id
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.db import get_db
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
@@ -295,10 +310,24 @@ class LangChainAgent:
|
||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||
logger.info("=" * 80)
|
||||
message_chat = message
|
||||
if config_id == None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:
|
||||
actual_config_id = config_id
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.db import get_db
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
history_term_memory = await self.term_memory_redis_read(end_user_id)
|
||||
if memory_flag:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
@@ -81,6 +82,7 @@ class Settings:
|
||||
VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query")
|
||||
|
||||
# Langfuse configuration
|
||||
LANGFUSE_ENABLED: bool = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
|
||||
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
|
||||
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
|
||||
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
|
||||
@@ -156,9 +158,6 @@ class Settings:
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
MEMORY_CONFIG_FILE: str = os.getenv("MEMORY_CONFIG_FILE", "config.json")
|
||||
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
|
||||
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
@@ -181,65 +180,6 @@ class Settings:
|
||||
return str(base_path / filename)
|
||||
return str(base_path)
|
||||
|
||||
def get_memory_config_path(self, config_file: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module configuration files.
|
||||
|
||||
Args:
|
||||
config_file: Optional config filename (defaults to MEMORY_CONFIG_FILE)
|
||||
|
||||
Returns:
|
||||
Full path to the config file
|
||||
"""
|
||||
if not config_file:
|
||||
config_file = self.MEMORY_CONFIG_FILE
|
||||
return str(Path(self.MEMORY_CONFIG_DIR) / config_file)
|
||||
|
||||
def load_memory_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Load memory module configuration from config.json.
|
||||
|
||||
Returns:
|
||||
Dictionary containing memory configuration
|
||||
"""
|
||||
config_path = self.get_memory_config_path(self.MEMORY_CONFIG_FILE)
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: Memory config file not found or malformed at {config_path}. Error: {e}")
|
||||
return {}
|
||||
|
||||
def load_memory_runtime_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Load memory module runtime configuration from runtime.json.
|
||||
|
||||
Returns:
|
||||
Dictionary containing runtime configuration
|
||||
"""
|
||||
runtime_path = self.get_memory_config_path(self.MEMORY_RUNTIME_FILE)
|
||||
try:
|
||||
with open(runtime_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: Memory runtime config not found or malformed at {runtime_path}. Error: {e}")
|
||||
return {"selections": {}}
|
||||
|
||||
def load_memory_dbrun_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Load memory module database run configuration from dbrun.json.
|
||||
|
||||
Returns:
|
||||
Dictionary containing dbrun configuration
|
||||
"""
|
||||
dbrun_path = self.get_memory_config_path(self.MEMORY_DBRUN_FILE)
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: Memory dbrun config not found or malformed at {dbrun_path}. Error: {e}")
|
||||
return {"selections": {}}
|
||||
|
||||
def ensure_memory_output_dir(self) -> None:
|
||||
"""
|
||||
Ensure the memory output directory exists.
|
||||
|
||||
@@ -326,7 +326,7 @@ def log_prompt_rendering(prompt_type: str, content: str) -> None:
|
||||
logger.info(log_message)
|
||||
|
||||
|
||||
def log_template_rendering(template_name: str, context: dict | None = None) -> None:
|
||||
def log_template_rendering(template_name: str, context: Optional[dict] = None) -> None:
|
||||
"""Log template rendering information.
|
||||
|
||||
Logs the template name and context keys for debugging template rendering.
|
||||
@@ -575,6 +575,43 @@ def get_named_logger(name: str) -> logging.Logger:
|
||||
return get_agent_logger(name)
|
||||
|
||||
|
||||
def get_config_logger() -> logging.Logger:
|
||||
"""Get a specialized logger for memory configuration operations.
|
||||
|
||||
Returns a logger configured specifically for configuration loading, validation,
|
||||
and model resolution operations with:
|
||||
- Logger name: memory.config
|
||||
- Output: Inherits from root logger (console + file)
|
||||
- Level: Inherits from root logger
|
||||
- Format: Standard format with timing information
|
||||
|
||||
This logger is optimized for configuration operations and includes
|
||||
structured logging for timing, validation steps, and error context.
|
||||
|
||||
Returns:
|
||||
Logger configured for memory configuration operations
|
||||
|
||||
Example:
|
||||
>>> logger = get_config_logger()
|
||||
>>> logger.info("Loading configuration", extra={
|
||||
... "config_id": 123,
|
||||
... "workspace_id": "uuid-here",
|
||||
... "operation": "load_config"
|
||||
... })
|
||||
"""
|
||||
# Ensure memory logging is initialized
|
||||
if not LoggingConfig._memory_loggers_initialized:
|
||||
LoggingConfig.setup_memory_logging()
|
||||
|
||||
# Get configuration logger with memory namespace
|
||||
logger = logging.getLogger("memory.config")
|
||||
|
||||
# The logger automatically inherits handlers, formatters, and level from root logger
|
||||
# through Python's logging hierarchy, so no additional configuration is needed
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def get_memory_logger(name: Optional[str] = None) -> logging.Logger:
|
||||
"""Get a standard logger for memory module components.
|
||||
|
||||
|
||||
@@ -9,11 +9,11 @@ import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +25,8 @@ async def create_input_message(
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
multimodal_processor: MultimodalProcessor
|
||||
multimodal_processor: MultimodalProcessor,
|
||||
memory_config: MemoryConfig,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create initial tool call message from user input.
|
||||
@@ -46,6 +47,7 @@ async def create_input_message(
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
multimodal_processor: Processor for handling image/audio inputs
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
State update with AIMessage containing tool_call
|
||||
@@ -53,7 +55,7 @@ async def create_input_message(
|
||||
Examples:
|
||||
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
|
||||
>>> result = await create_input_message(
|
||||
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor
|
||||
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config
|
||||
... )
|
||||
>>> result["messages"][0].tool_calls[0]["name"]
|
||||
'Split_The_Problem'
|
||||
@@ -123,20 +125,24 @@ async def create_input_message(
|
||||
f"with ID: {tool_call_id}"
|
||||
)
|
||||
|
||||
# Build tool arguments
|
||||
tool_args = {
|
||||
"sentence": last_message,
|
||||
"sessionid": session_id,
|
||||
"messages_id": str(uuid_str),
|
||||
"search_switch": search_switch,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"memory_config": memory_config,
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": tool_name,
|
||||
"args": {
|
||||
"sentence": last_message,
|
||||
"sessionid": session_id,
|
||||
"messages_id": str(uuid_str),
|
||||
"search_switch": search_switch,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
},
|
||||
"args": tool_args,
|
||||
"id": tool_call_id
|
||||
}]
|
||||
)
|
||||
|
||||
@@ -9,14 +9,14 @@ import logging
|
||||
import time
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||
extract_content_payload,
|
||||
extract_tool_call_id,
|
||||
extract_content_payload
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,8 +38,9 @@ class ToolExecutionNode:
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool: Callable,
|
||||
@@ -49,8 +50,9 @@ class ToolExecutionNode:
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
parameter_builder: ParameterBuilder,
|
||||
storage_type:str,
|
||||
user_rag_memory_id:str
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
):
|
||||
"""
|
||||
Initialize the tool execution node.
|
||||
@@ -63,6 +65,9 @@ class ToolExecutionNode:
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
storage_type: Storage type for the workspace
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
self.tool_node = ToolNode([tool])
|
||||
self.id = node_id
|
||||
@@ -72,9 +77,10 @@ class ToolExecutionNode:
|
||||
self.apply_id = apply_id
|
||||
self.group_id = group_id
|
||||
self.parameter_builder = parameter_builder
|
||||
self.storage_type=storage_type
|
||||
self.user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
self.storage_type = storage_type
|
||||
self.user_rag_memory_id = user_rag_memory_id
|
||||
self.memory_config = memory_config
|
||||
|
||||
logger.info(
|
||||
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
|
||||
)
|
||||
@@ -124,8 +130,12 @@ class ToolExecutionNode:
|
||||
# Extract content payload using state extractors
|
||||
content = extract_content_payload(last_message)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}"
|
||||
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}"
|
||||
)
|
||||
# Log raw message content for debugging
|
||||
if hasattr(last_message, 'content'):
|
||||
raw = last_message.content
|
||||
logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -143,8 +153,9 @@ class ToolExecutionNode:
|
||||
search_switch=self.search_switch,
|
||||
apply_id=self.apply_id,
|
||||
group_id=self.group_id,
|
||||
memory_config=self.memory_config,
|
||||
storage_type=self.storage_type,
|
||||
user_rag_memory_id=self.user_rag_memory_id
|
||||
user_rag_memory_id=self.user_rag_memory_id,
|
||||
)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
|
||||
@@ -179,7 +190,29 @@ class ToolExecutionNode:
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution completed"
|
||||
)
|
||||
|
||||
# Return the result directly - it already contains the messages list
|
||||
# Check for error in tool response
|
||||
error_entry = None
|
||||
if result and "messages" in result:
|
||||
for msg in result["messages"]:
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
import json
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict) and "error" in parsed:
|
||||
error_msg = parsed["error"]
|
||||
logger.warning(
|
||||
f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}"
|
||||
)
|
||||
error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Return result with error tracking if error was found
|
||||
if error_entry:
|
||||
result["errors"] = [error_entry]
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
@@ -187,13 +220,15 @@ class ToolExecutionNode:
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return error as ToolMessage to maintain message chain consistency
|
||||
# Track error in state and return error message
|
||||
from langchain_core.messages import ToolMessage
|
||||
error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id}
|
||||
return {
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=f"Error executing tool: {str(e)}",
|
||||
tool_call_id=f"{self.id}_{tool_call_id}"
|
||||
)
|
||||
]
|
||||
],
|
||||
"errors": [error_entry]
|
||||
}
|
||||
|
||||
@@ -1,38 +1,26 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from functools import partial
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
# Import new modular components
|
||||
from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message
|
||||
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Verify_continue,
|
||||
Retrieve_continue,
|
||||
Split_continue
|
||||
from app.core.memory.agent.langgraph_graph.nodes import (
|
||||
ToolExecutionNode,
|
||||
create_input_message,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -44,9 +32,9 @@ redisdb=os.getenv('REDISDB')
|
||||
redispassword=os.getenv('REDISPASSWORD')
|
||||
counter = COUNTState(limit=3)
|
||||
|
||||
# 在工作流中添加循环计数更新
|
||||
# Update loop count in workflow
|
||||
async def update_loop_count(state):
|
||||
"""更新循环计数器"""
|
||||
"""Update loop counter"""
|
||||
current_count = state.get("loop_count", 0)
|
||||
return {"loop_count": current_count + 1}
|
||||
|
||||
@@ -54,13 +42,13 @@ async def update_loop_count(state):
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
messages = state["messages"]
|
||||
|
||||
# 添加边界检查
|
||||
# Add boundary check
|
||||
if not messages:
|
||||
return END
|
||||
counter.add(1) # 累加 1
|
||||
counter.add(1) # Increment by 1
|
||||
|
||||
loop_count = counter.get_total()
|
||||
logger.debug(f"[should_continue] 当前循环次数: {loop_count}")
|
||||
logger.debug(f"[should_continue] Current loop count: {loop_count}")
|
||||
|
||||
last_message = messages[-1]
|
||||
last_message_str = str(last_message).replace('\\', '')
|
||||
@@ -71,15 +59,15 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status_tools:
|
||||
if loop_count < 2: # 最大循环次数 3
|
||||
if loop_count < 2: # Maximum loop count is 3
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# 添加默认返回值,避免返回 None
|
||||
# Add default return value to avoid returning None
|
||||
counter.reset()
|
||||
return "Summary" # 或根据业务需求选择合适的默认值
|
||||
return "Summary" # Default based on business requirements
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
@@ -115,8 +103,8 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
# 添加默认返回值,避免返回 None
|
||||
return 'Retrieve_Summary' # 或根据业务逻辑选择合适的默认值
|
||||
# Add default return value to avoid returning None
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
@@ -151,46 +139,7 @@ def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
# 在 input_sentence 函数中修改参数名称
|
||||
async def input_sentence(state, name, id, search_switch,apply_id,group_id):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1].content if messages else ""
|
||||
|
||||
if last_message.endswith('.jpg') or last_message.endswith('.png'):
|
||||
last_message=await picture_model_requests(last_message)
|
||||
if any(last_message.endswith(ext) for ext in audio_extensions):
|
||||
last_message=await Vico_recognition([last_message]).run()
|
||||
logger.debug(f"Audio recognition result: {last_message}")
|
||||
|
||||
|
||||
uuid_str = uuid.uuid4()
|
||||
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
namespace = str(id).split('_id_')[1]
|
||||
if 'verified_data' in str(last_message):
|
||||
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
|
||||
last_message = re.findall(r'"query": "(.*?)",', str(messages_last))[0]
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": name,
|
||||
"args": {
|
||||
"sentence": last_message,
|
||||
'sessionid': id,
|
||||
'messages_id': str(uuid_str),
|
||||
"search_switch": search_switch, # 正确地将 search_switch 放入 args 中
|
||||
"apply_id":apply_id,
|
||||
"group_id":group_id
|
||||
},
|
||||
"id": id + f'_{uuid_str}'
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
return 'Split_The_Problem' # Default case
|
||||
|
||||
|
||||
class ProblemExtensionNode:
|
||||
@@ -208,30 +157,28 @@ class ProblemExtensionNode:
|
||||
async def __call__(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1] if messages else ""
|
||||
logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}")
|
||||
if self.tool_name=='Input_Summary':
|
||||
tool_call =re.findall("'id': '(.*?)'",str(last_message))[0]
|
||||
else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
|
||||
# try:
|
||||
# content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message
|
||||
# except:
|
||||
# content = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
# 尝试从上一工具的结果中提取实际的内容载荷(而不是整个对象的字符串表示)
|
||||
logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}")
|
||||
if self.tool_name == 'Input_Summary':
|
||||
tool_call = re.findall("'id': '(.*?)'", str(last_message))[0]
|
||||
else:
|
||||
tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
|
||||
|
||||
# Try to extract actual content payload from previous tool result
|
||||
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
extracted_payload = None
|
||||
# 捕获 ToolMessage 的 content 字段(支持单/双引号),并避免贪婪匹配
|
||||
# Capture ToolMessage content field (supports single/double quotes), avoid greedy matching
|
||||
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
|
||||
if m:
|
||||
extracted_payload = m.group(1)
|
||||
else:
|
||||
# 回退:直接尝试使用原始字符串
|
||||
# Fallback: use raw string directly
|
||||
extracted_payload = raw_msg
|
||||
|
||||
# 优先尝试将内容解析为 JSON
|
||||
# Try to parse content as JSON first
|
||||
try:
|
||||
content = json.loads(extracted_payload)
|
||||
except Exception:
|
||||
# 尝试从文本中提取 JSON 片段再解析
|
||||
# Try to extract JSON fragment from text and parse
|
||||
parsed = None
|
||||
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
|
||||
for cand in candidates:
|
||||
@@ -240,14 +187,14 @@ class ProblemExtensionNode:
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
# 如果仍然失败,则以原始字符串作为内容
|
||||
# If still fails, use raw string as content
|
||||
content = parsed if parsed is not None else extracted_payload
|
||||
|
||||
# 根据工具名称构建正确的参数
|
||||
# Build correct parameters based on tool name
|
||||
tool_args = {}
|
||||
|
||||
if self.tool_name == "Verify":
|
||||
# Verify工具需要context和usermessages参数
|
||||
# Verify tool requires context and usermessages parameters
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
@@ -256,7 +203,7 @@ class ProblemExtensionNode:
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Retrieve":
|
||||
# Retrieve工具需要context和usermessages参数
|
||||
# Retrieve tool requires context and usermessages parameters
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
@@ -266,9 +213,9 @@ class ProblemExtensionNode:
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary":
|
||||
# Summary工具需要字符串类型的context参数
|
||||
# Summary tool requires string type context parameter
|
||||
if isinstance(content, dict):
|
||||
# 将字典转换为JSON字符串
|
||||
# Convert dict to JSON string
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
@@ -276,24 +223,24 @@ class ProblemExtensionNode:
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary_fails":
|
||||
# Summary工具需要字符串类型的context参数
|
||||
# Summary_fails tool requires string type context parameter
|
||||
if isinstance(content, dict):
|
||||
# 将字典转换为JSON字符串
|
||||
# Convert dict to JSON string
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name=='Input_Summary':
|
||||
tool_args["context"] =str(last_message)
|
||||
elif self.tool_name == 'Input_Summary':
|
||||
tool_args["context"] = str(last_message)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["search_switch"] = str(self.search_switch)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
tool_args["storage_type"] = getattr(self, 'storage_type', "")
|
||||
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
|
||||
elif self.tool_name=='Retrieve_Summary' :
|
||||
elif self.tool_name == 'Retrieve_Summary':
|
||||
# Retrieve_Summary expects dict directly, not JSON string
|
||||
# content might be a JSON string, try to parse it
|
||||
if isinstance(content, str):
|
||||
@@ -320,7 +267,7 @@ class ProblemExtensionNode:
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
else:
|
||||
# 其他工具使用context参数
|
||||
# Other tools use context parameter
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
@@ -349,12 +296,24 @@ class ProblemExtensionNode:
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config_id=None,storage_type=None,user_rag_memory_id=None):
|
||||
async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None):
|
||||
"""
|
||||
Create a read graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
namespace: Namespace identifier
|
||||
tools: MCP tools loaded from session
|
||||
search_switch: Search mode switch ("0", "1", or "2")
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type (optional)
|
||||
user_rag_memory_id: User RAG memory ID (optional)
|
||||
"""
|
||||
memory = InMemorySaver()
|
||||
tool=[i.name for i in tools ]
|
||||
tool = [i.name for i in tools]
|
||||
logger.info(f"Initializing read graph with tools: {tool}")
|
||||
if config_id:
|
||||
logger.info(f"使用配置 ID: {config_id}")
|
||||
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
|
||||
|
||||
# Extract tool functions
|
||||
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
|
||||
@@ -382,9 +341,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
Retrieve_node = ToolExecutionNode(
|
||||
tool=Retrieve_,
|
||||
node_id="Retrieve_id",
|
||||
@@ -394,9 +354,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
Verify_node = ToolExecutionNode(
|
||||
tool=Verify_,
|
||||
node_id="Verify_id",
|
||||
@@ -406,7 +367,8 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
Summary_node = ToolExecutionNode(
|
||||
@@ -418,9 +380,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
Summary_fails_node = ToolExecutionNode(
|
||||
tool=Summary_fails_,
|
||||
node_id="Summary_fails_id",
|
||||
@@ -430,9 +393,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
Retrieve_Summary_node = ToolExecutionNode(
|
||||
tool=Retrieve_Summary_,
|
||||
node_id="Retrieve_Summary_id",
|
||||
@@ -442,9 +406,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
Input_Summary_node = ToolExecutionNode(
|
||||
tool=Input_Summary_,
|
||||
node_id="Input_Summary_id",
|
||||
@@ -454,16 +419,16 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
async def content_input_node(state):
|
||||
state_search_switch = state.get("search_switch", search_switch)
|
||||
|
||||
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
|
||||
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
|
||||
|
||||
|
||||
return await create_input_message(
|
||||
state=state,
|
||||
tool_name=tool_name,
|
||||
@@ -471,7 +436,8 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
multimodal_processor=multimodal_processor
|
||||
multimodal_processor=multimodal_processor,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -501,8 +467,3 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
|
||||
graph = workflow.compile(checkpointer=memory)
|
||||
yield graph
|
||||
|
||||
|
||||
# 添加到文件末尾或创建新的执行脚本
|
||||
# 在 memory_agent_service.py 文件中添加以下函数
|
||||
|
||||
|
||||
@@ -128,6 +128,15 @@ def extract_content_payload(message: Any) -> Any:
|
||||
# For ToolMessages (responses from tools), extract from content
|
||||
if hasattr(message, "content"):
|
||||
raw_content = message.content
|
||||
logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}")
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(raw_content, list):
|
||||
for block in raw_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
raw_content = block.get('text', '')
|
||||
logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}")
|
||||
break
|
||||
|
||||
# If content is empty and this is an AIMessage with tool_calls,
|
||||
# extract from args (this handles the initial tool call from content_input)
|
||||
@@ -140,13 +149,16 @@ def extract_content_payload(message: Any) -> Any:
|
||||
|
||||
# If content is already a dict or list, return it directly
|
||||
if isinstance(raw_content, (dict, list)):
|
||||
logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}")
|
||||
return raw_content
|
||||
|
||||
# Try to parse as JSON
|
||||
if isinstance(raw_content, str):
|
||||
# First, try direct JSON parsing
|
||||
try:
|
||||
return json.loads(raw_content)
|
||||
parsed = json.loads(raw_content)
|
||||
logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
@@ -156,9 +168,12 @@ def extract_content_payload(message: Any) -> Any:
|
||||
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
|
||||
for candidate in json_candidates:
|
||||
try:
|
||||
return json.loads(candidate)
|
||||
parsed = json.loads(candidate)
|
||||
logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
|
||||
# If all parsing attempts fail, return the raw content
|
||||
logger.info(f"extract_content_payload: returning raw content (parsing failed)")
|
||||
return raw_content
|
||||
|
||||
@@ -1,116 +1,71 @@
|
||||
import asyncio
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import add_messages, StateGraph
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
import warnings
|
||||
import sys
|
||||
from langchain_core.messages import AIMessage
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
import asyncio
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
@asynccontextmanager
|
||||
async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None):
|
||||
logger.info("加载 MCP 工具: %s", [t.name for t in tools])
|
||||
if config_id:
|
||||
logger.info(f"使用配置 ID: {config_id}")
|
||||
|
||||
data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None)
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig):
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
logger.info("Loading MCP tools: %s", [t.name for t in tools])
|
||||
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
|
||||
|
||||
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
|
||||
|
||||
if not data_type_tool or not data_write_tool:
|
||||
logger.error('不存在数据存储工具', exc_info=True)
|
||||
raise ValueError('不存在数据存储工具')
|
||||
# ToolNode
|
||||
write_node = ToolNode([data_write_tool])
|
||||
if not data_write_tool:
|
||||
logger.error("Data_write tool not found", exc_info=True)
|
||||
raise ValueError("Data_write tool not found")
|
||||
|
||||
write_node = ToolNode([data_write_tool])
|
||||
|
||||
async def call_model(state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
content = last_message[1] if isinstance(last_message, tuple) else last_message.content
|
||||
|
||||
# 调用 Data_type_differentiation 工具
|
||||
try:
|
||||
raw_result = await data_type_tool.ainvoke({
|
||||
"context": last_message[1] if isinstance(last_message, tuple) else last_message.content
|
||||
})
|
||||
|
||||
# MCP工具返回的是列表格式,需要提取内容
|
||||
logger.debug(f"Data_type_differentiation raw result type: {type(raw_result)}, value: {raw_result}")
|
||||
|
||||
# 处理不同的返回格式
|
||||
if isinstance(raw_result, list) and len(raw_result) > 0:
|
||||
# MCP工具返回格式: [{"type": "text", "text": "..."}]
|
||||
result_text = raw_result[0].get("text", "{}") if isinstance(raw_result[0], dict) else str(raw_result[0])
|
||||
elif isinstance(raw_result, str):
|
||||
result_text = raw_result
|
||||
else:
|
||||
result_text = str(raw_result)
|
||||
|
||||
# 解析JSON字符串
|
||||
try:
|
||||
result = json.loads(result_text)
|
||||
except json.JSONDecodeError as je:
|
||||
logger.error(f"Failed to parse result as JSON: {result_text}, error: {je}")
|
||||
return {"messages": [AIMessage(content=json.dumps({
|
||||
"status": "error",
|
||||
"message": f"Invalid JSON response from Data_type_differentiation: {str(je)}"
|
||||
}))]}
|
||||
|
||||
# 检查是否有错误
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("message", "Unknown error in Data_type_differentiation")
|
||||
logger.error(f"Data_type_differentiation 返回错误: {error_msg}")
|
||||
return {"messages": [AIMessage(content=json.dumps({
|
||||
"status": "error",
|
||||
"message": error_msg
|
||||
}))]}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用 Data_type_differentiation 失败: {e}", exc_info=True)
|
||||
return {"messages": [AIMessage(content=json.dumps({
|
||||
"status": "error",
|
||||
"message": f"Data type differentiation failed: {str(e)}"
|
||||
}))]}
|
||||
|
||||
# 调用 Data_write,传递 config_id
|
||||
# Call Data_write directly with memory_config
|
||||
write_params = {
|
||||
"content": result.get("context", last_message.content if hasattr(last_message, 'content') else str(last_message)),
|
||||
"content": content,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"user_id": user_id
|
||||
"user_id": user_id,
|
||||
"memory_config": memory_config,
|
||||
}
|
||||
|
||||
# 如果提供了 config_id,添加到参数中
|
||||
if config_id:
|
||||
write_params["config_id"] = config_id
|
||||
logger.debug(f"传递 config_id 到 Data_write: {config_id}")
|
||||
|
||||
try:
|
||||
write_result = await data_write_tool.ainvoke(write_params)
|
||||
logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}")
|
||||
|
||||
if isinstance(write_result, dict):
|
||||
content = write_result.get("data", str(write_result))
|
||||
else:
|
||||
content = str(write_result)
|
||||
logger.info("写入内容: %s", content)
|
||||
return {"messages": [AIMessage(content=content)]}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用 Data_write 失败: {e}", exc_info=True)
|
||||
return {"messages": [AIMessage(content=json.dumps({
|
||||
"status": "error",
|
||||
"message": f"Data write failed: {str(e)}"
|
||||
}))]}
|
||||
write_result = await data_write_tool.ainvoke(write_params)
|
||||
|
||||
if isinstance(write_result, dict):
|
||||
result_content = write_result.get("data", str(write_result))
|
||||
else:
|
||||
result_content = str(write_result)
|
||||
logger.info("Write content: %s", result_content)
|
||||
return {"messages": [AIMessage(content=result_content)]}
|
||||
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("content_input", call_model)
|
||||
|
||||
@@ -10,19 +10,19 @@ Package structure:
|
||||
- models: Pydantic response models
|
||||
- services: Business logic services
|
||||
"""
|
||||
from app.core.memory.agent.mcp_server.server import (
|
||||
mcp,
|
||||
initialize_context,
|
||||
main,
|
||||
get_context_resource
|
||||
)
|
||||
# from app.core.memory.agent.mcp_server.server import (
|
||||
# mcp,
|
||||
# initialize_context,
|
||||
# main,
|
||||
# get_context_resource
|
||||
# )
|
||||
|
||||
# Import tools to register them (but don't export them)
|
||||
from app.core.memory.agent.mcp_server import tools
|
||||
# # Import tools to register them (but don't export them)
|
||||
# from app.core.memory.agent.mcp_server import tools
|
||||
|
||||
__all__ = [
|
||||
'mcp',
|
||||
'initialize_context',
|
||||
'main',
|
||||
'get_context_resource',
|
||||
]
|
||||
# __all__ = [
|
||||
# 'mcp',
|
||||
# 'initialize_context',
|
||||
# 'main',
|
||||
# 'get_context_resource',
|
||||
# ]
|
||||
@@ -6,19 +6,15 @@ in the context for dependency injection into tool functions.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.services.search_service import SearchService
|
||||
from app.core.memory.agent.mcp_server.services.session_service import SessionService
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
|
||||
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -78,17 +74,11 @@ def initialize_context():
|
||||
logger.info("Registering session_store in context")
|
||||
mcp.session_store = store
|
||||
|
||||
# Register LLM client
|
||||
try:
|
||||
logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}")
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
mcp.llm_client = llm_client
|
||||
logger.info("llm_client registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register llm_client: {e}", exc_info=True)
|
||||
# 注册一个 None 值,避免工具调用时找不到资源
|
||||
mcp.llm_client = None
|
||||
logger.warning("llm_client set to None due to initialization failure")
|
||||
# Note: LLM client is NOT loaded at server startup
|
||||
# It should be loaded dynamically when needed, with config_id passed explicitly
|
||||
# to make_write_graph or make_read_graph functions
|
||||
logger.info("LLM client will be loaded dynamically with config_id when needed")
|
||||
mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id
|
||||
|
||||
# Register application settings (renamed to avoid conflict with FastMCP's settings)
|
||||
logger.info("Registering app_settings in context")
|
||||
@@ -124,26 +114,20 @@ def main():
|
||||
Initializes context and starts the server with SSE transport.
|
||||
"""
|
||||
try:
|
||||
# logger.info("Starting MCP server initialization")
|
||||
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
|
||||
logger.info("Starting MCP server initialization")
|
||||
# Initialize context resources
|
||||
initialize_context()
|
||||
|
||||
# Import and register tools
|
||||
# logger.info("Importing MCP tools")
|
||||
from app.core.memory.agent.mcp_server.tools import (
|
||||
# Import and register tools (imports trigger tool registration)
|
||||
from app.core.memory.agent.mcp_server.tools import ( # noqa: F401
|
||||
data_tools,
|
||||
problem_tools,
|
||||
retrieval_tools,
|
||||
verification_tools,
|
||||
summary_tools,
|
||||
data_tools
|
||||
verification_tools,
|
||||
)
|
||||
# logger.info("All MCP tools imported and registered")
|
||||
|
||||
# Log registered tools for debugging
|
||||
import asyncio
|
||||
tools_list = asyncio.run(mcp.list_tools())
|
||||
# logger.info(f"Registered {len(tools_list)} MCP tools: {[t.name for t in tools_list]}")
|
||||
# Tools are registered via imports above
|
||||
|
||||
# Get MCP port from environment (default: 8081)
|
||||
mcp_port = int(os.getenv("MCP_PORT", "8081"))
|
||||
|
||||
@@ -4,22 +4,22 @@ Parameter Builder for constructing tool call arguments.
|
||||
This service provides tool-specific parameter transformation logic
|
||||
to build correct arguments for each tool type.
|
||||
"""
|
||||
import json
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ParameterBuilder:
|
||||
"""Service for building tool call arguments based on tool type."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the parameter builder."""
|
||||
logger.info("ParameterBuilder initialized")
|
||||
|
||||
|
||||
def build_tool_args(
|
||||
self,
|
||||
tool_name: str,
|
||||
@@ -28,8 +28,9 @@ class ParameterBuilder:
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build tool arguments based on tool type.
|
||||
@@ -48,6 +49,7 @@ class ParameterBuilder:
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||
|
||||
@@ -58,18 +60,19 @@ class ParameterBuilder:
|
||||
base_args = {
|
||||
"usermessages": tool_call_id,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
"group_id": group_id,
|
||||
"memory_config": memory_config,
|
||||
}
|
||||
|
||||
|
||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||
base_args["storage_type"] = storage_type if storage_type is not None else ""
|
||||
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
|
||||
|
||||
# Tool-specific argument construction
|
||||
if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']:
|
||||
# Verify expects dict context
|
||||
if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]:
|
||||
# These tools expect dict context
|
||||
return {
|
||||
"context": content if isinstance(content, dict) else {},
|
||||
"context": content if isinstance(content, dict) else {"content": content},
|
||||
**base_args
|
||||
}
|
||||
|
||||
|
||||
@@ -4,21 +4,31 @@ Search Service for executing hybrid search and processing results.
|
||||
This service provides clean search result processing with content extraction
|
||||
and deduplication.
|
||||
"""
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the search service."""
|
||||
|
||||
def __init__(self, memory_config: "MemoryConfig" = None):
|
||||
"""
|
||||
Initialize the search service.
|
||||
|
||||
Args:
|
||||
memory_config: Optional MemoryConfig for embedding model configuration.
|
||||
If not provided, must be passed to execute_hybrid_search.
|
||||
"""
|
||||
self.memory_config = memory_config
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
def extract_content_from_result(self, result: dict) -> str:
|
||||
@@ -93,12 +103,13 @@ class SearchService:
|
||||
self,
|
||||
group_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
limit: int = 15,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False
|
||||
return_raw_results: bool = False,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
@@ -112,6 +123,7 @@ class SearchService:
|
||||
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
memory_config: MemoryConfig object for embedding model. Falls back to self.memory_config if not provided.
|
||||
|
||||
Returns:
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
@@ -119,12 +131,17 @@ class SearchService:
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
|
||||
# Use provided memory_config or fall back to instance config
|
||||
config = memory_config or self.memory_config
|
||||
if not config:
|
||||
raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search")
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
|
||||
try:
|
||||
# Execute search
|
||||
# Execute search using memory_config
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
@@ -132,7 +149,8 @@ class SearchService:
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
rerank_alpha=rerank_alpha
|
||||
memory_config=config,
|
||||
rerank_alpha=rerank_alpha,
|
||||
)
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
|
||||
@@ -3,16 +3,20 @@ Data Tools for data type differentiation and writing.
|
||||
|
||||
This module contains MCP tools for distinguishing data types and writing data.
|
||||
"""
|
||||
import os
|
||||
|
||||
from mcp.server.fastmcp import Context
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.models.retrieval_models import (
|
||||
DistinguishTypeResponse,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -20,7 +24,8 @@ logger = get_agent_logger(__name__)
|
||||
@mcp.tool()
|
||||
async def Data_type_differentiation(
|
||||
ctx: Context,
|
||||
context: str
|
||||
context: str,
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Distinguish the type of data (read or write).
|
||||
@@ -28,6 +33,7 @@ async def Data_type_differentiation(
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Text to analyze for type differentiation
|
||||
memory_config: MemoryConfig object containing LLM configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' with the original text and 'type' field
|
||||
@@ -35,7 +41,11 @@ async def Data_type_differentiation(
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Get LLM client from memory_config using factory pattern
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Render template
|
||||
try:
|
||||
@@ -53,7 +63,7 @@ async def Data_type_differentiation(
|
||||
"type": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
@@ -98,7 +108,7 @@ async def Data_write(
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
config_id: str
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
@@ -109,7 +119,7 @@ async def Data_write(
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
config_id: Configuration ID for processing (optional, integer)
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status', 'saved_to', and 'data' fields
|
||||
@@ -118,32 +128,28 @@ async def Data_write(
|
||||
# Ensure output directory exists
|
||||
os.makedirs("data_output", exist_ok=True)
|
||||
file_path = os.path.join("data_output", "user_data.csv")
|
||||
|
||||
# Write data using utility function
|
||||
try:
|
||||
await write(content, user_id, apply_id, group_id, config_id=config_id)
|
||||
logger.info(f"写入成功!Config ID: {config_id if config_id else 'None'}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"saved_to": file_path,
|
||||
"data": content,
|
||||
"config_id": config_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"写入失败: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Data_write failed: {e}",
|
||||
exc_info=True
|
||||
|
||||
# Write data - clients are constructed inside write() from memory_config
|
||||
await write(
|
||||
content=content,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"saved_to": file_path,
|
||||
"data": content,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data_write failed: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
"message": str(e),
|
||||
}
|
||||
|
||||
@@ -2,25 +2,24 @@
|
||||
Problem Tools for question segmentation and extension.
|
||||
|
||||
This module contains MCP tools for breaking down and extending user questions.
|
||||
LLM clients are constructed from MemoryConfig when needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.problem_models import (
|
||||
ProblemBreakdownItem,
|
||||
ProblemBreakdownResponse,
|
||||
ExtendedQuestionItem,
|
||||
ProblemExtensionResponse
|
||||
ProblemExtensionResponse,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -32,7 +31,8 @@ async def Split_The_Problem(
|
||||
sessionid: str,
|
||||
messages_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Segment the dialogue or sentence into sub-problems.
|
||||
@@ -44,17 +44,22 @@ async def Split_The_Problem(
|
||||
messages_id: Message identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' (JSON string of split results) and 'original' sentence
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Extract user ID from session
|
||||
user_id = session_service.resolve_user_id(sessionid)
|
||||
@@ -116,8 +121,8 @@ async def Split_The_Problem(
|
||||
)
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
logger.info("问题拆分")
|
||||
logger.info(f"问题拆分结果==>>:{split_result}")
|
||||
logger.info("Problem splitting")
|
||||
logger.info(f"Problem split result: {split_result}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
@@ -150,7 +155,7 @@ async def Split_The_Problem(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('问题拆分', duration)
|
||||
log_time('Problem splitting', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -160,8 +165,9 @@ async def Problem_Extension(
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Extend the problem with additional sub-questions.
|
||||
@@ -172,6 +178,7 @@ async def Problem_Extension(
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
@@ -179,12 +186,16 @@ async def Problem_Extension(
|
||||
dict: Contains 'context' (aggregated questions) and 'original' question
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID from usermessages
|
||||
from app.core.memory.agent.utils.messages_tool import Resolve_username
|
||||
@@ -250,8 +261,8 @@ async def Problem_Extension(
|
||||
)
|
||||
aggregated_dict = {}
|
||||
|
||||
logger.info("问题扩展")
|
||||
logger.info(f"问题扩展==>>:{aggregated_dict}")
|
||||
logger.info("Problem extension")
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
@@ -290,4 +301,4 @@ async def Problem_Extension(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('问题扩展', duration)
|
||||
log_time('Problem extension', duration)
|
||||
|
||||
@@ -3,25 +3,24 @@ Retrieval Tools for database and context retrieval.
|
||||
|
||||
This module contains MCP tools for retrieving data using hybrid search.
|
||||
"""
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import os
|
||||
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
|
||||
load_dotenv()
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@@ -32,8 +31,9 @@ async def Retrieve(
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Retrieve data from the database using hybrid search.
|
||||
@@ -44,6 +44,7 @@ async def Retrieve(
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
@@ -66,6 +67,7 @@ async def Retrieve(
|
||||
}
|
||||
start = time.time()
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}")
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
@@ -77,7 +79,13 @@ async def Retrieve(
|
||||
if isinstance(context, dict):
|
||||
# Process dict context with extended questions
|
||||
all_items = []
|
||||
logger.info(f"Retrieve: context keys={list(context.keys())}")
|
||||
content, original = await Retriev_messages_deal(context)
|
||||
logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}")
|
||||
logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'")
|
||||
|
||||
if not original:
|
||||
logger.warning(f"Retrieve: original query is empty! context={context}")
|
||||
|
||||
# Extract all query items from content
|
||||
# content is like {original_question: [extended_questions...], ...}
|
||||
@@ -113,9 +121,11 @@ async def Retrieve(
|
||||
clean_content = ''
|
||||
raw_results=''
|
||||
cleaned_query = question
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
|
||||
databases_anser.append({
|
||||
"Query_small": cleaned_query,
|
||||
@@ -206,9 +216,11 @@ async def Retrieve(
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = query
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
# Keep structure for Verify/Retrieve_Summary compatibility
|
||||
dup_databases = {
|
||||
"Query": cleaned_query,
|
||||
@@ -236,7 +248,7 @@ async def Retrieve(
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
|
||||
f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
|
||||
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
|
||||
)
|
||||
|
||||
@@ -279,4 +291,4 @@ async def Retrieve(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
log_time('Retrieval', duration)
|
||||
|
||||
@@ -2,33 +2,32 @@
|
||||
Summary Tools for data summarization.
|
||||
|
||||
This module contains MCP tools for summarizing retrieved data and generating responses.
|
||||
LLM clients are constructed from MemoryConfig when needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.summary_models import (
|
||||
SummaryData,
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
RetrieveSummaryData,
|
||||
RetrieveSummaryResponse
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Resolve_username,
|
||||
Summary_messages_deal,
|
||||
Resolve_username
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -40,8 +39,9 @@ async def Summary(
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize the verified data.
|
||||
@@ -52,6 +52,7 @@ async def Summary(
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
@@ -59,12 +60,16 @@ async def Summary(
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
@@ -155,7 +160,7 @@ async def Summary(
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"验证之后的总结==>>:{aimessages}")
|
||||
logger.info(f"Summary after verification: {aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
@@ -163,7 +168,7 @@ async def Summary(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('总结', duration)
|
||||
log_time('Summary', duration)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
@@ -180,8 +185,9 @@ async def Retrieve_Summary(
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize data directly from retrieval results.
|
||||
@@ -192,6 +198,7 @@ async def Retrieve_Summary(
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
@@ -202,9 +209,13 @@ async def Retrieve_Summary(
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
@@ -212,6 +223,8 @@ async def Retrieve_Summary(
|
||||
|
||||
|
||||
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
|
||||
logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}")
|
||||
|
||||
if isinstance(context, dict):
|
||||
if "content" in context:
|
||||
inner = context["content"]
|
||||
@@ -252,17 +265,19 @@ async def Retrieve_Summary(
|
||||
|
||||
query = context_dict.get("Query", "")
|
||||
expansion_issue = context_dict.get("Expansion_issue", [])
|
||||
|
||||
logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}")
|
||||
logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}")
|
||||
|
||||
# Extract retrieve_info from expansion_issue
|
||||
retrieve_info = []
|
||||
for item in expansion_issue:
|
||||
# Check for both Answer_Small and Answer_Samll (typo) for backward compatibility
|
||||
# Check for both Answer_Small and Answer_Small (typo) for backward compatibility
|
||||
answer = None
|
||||
if isinstance(item, dict):
|
||||
if "Answer_Small" in item:
|
||||
answer = item["Answer_Small"]
|
||||
elif "Answer_Samll" in item:
|
||||
answer = item["Answer_Samll"]
|
||||
|
||||
|
||||
if answer is not None:
|
||||
# Handle both string and list formats
|
||||
@@ -350,7 +365,7 @@ async def Retrieve_Summary(
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"检索之后的总结==>>:{aimessages}")
|
||||
logger.info(f"Summary after retrieval: {aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
@@ -358,7 +373,7 @@ async def Retrieve_Summary(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索总结', duration)
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
@@ -384,8 +399,9 @@ async def Input_Summary(
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a quick summary for direct input without verification.
|
||||
@@ -397,6 +413,7 @@ async def Input_Summary(
|
||||
search_switch: Search switch value for routing ('2' for summaries only)
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
@@ -406,21 +423,16 @@ async def Input_Summary(
|
||||
start = time.time()
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# Initialize variables to avoid UnboundLocalError
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
search_service = get_context_resource(ctx, 'search_service')
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
search_service = get_context_resource(ctx, "search_service")
|
||||
|
||||
# Check if llm_client is None
|
||||
if llm_client is None:
|
||||
error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
@@ -479,7 +491,7 @@ async def Input_Summary(
|
||||
|
||||
# Add storage-specific parameters
|
||||
|
||||
'''检索'''
|
||||
# Retrieval
|
||||
if search_switch == '2':
|
||||
search_params["include"] = ["summaries"]
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
@@ -509,12 +521,16 @@ async def Input_Summary(
|
||||
except:
|
||||
retrieve_info=''
|
||||
raw_results=['']
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
logger.info("Input_Summary: 使用 summary 进行检索")
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
logger.info("Input_Summary: Using summary for retrieval")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -547,7 +563,7 @@ async def Input_Summary(
|
||||
)
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
@@ -587,7 +603,7 @@ async def Input_Summary(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
log_time('Retrieval', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
||||
@@ -5,20 +5,19 @@ This module contains MCP tools for verifying retrieved data.
|
||||
"""
|
||||
import time
|
||||
|
||||
from jinja2 import Template
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.verify_tool import VerifyTool
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Verify_messages_deal,
|
||||
Retrieve_verify_tool_messages_deal,
|
||||
Resolve_username
|
||||
)
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Resolve_username,
|
||||
Retrieve_verify_tool_messages_deal,
|
||||
Verify_messages_deal,
|
||||
)
|
||||
from app.core.memory.agent.utils.verify_tool import VerifyTool
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from jinja2 import Template
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -30,6 +29,7 @@ async def Verify(
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
@@ -42,6 +42,7 @@ async def Verify(
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
@@ -91,8 +92,12 @@ async def Verify(
|
||||
|
||||
|
||||
|
||||
# Call verification workflow
|
||||
verify_tool = VerifyTool(system_prompt, messages)
|
||||
# Call verification workflow with LLM model ID from memory_config
|
||||
verify_tool = VerifyTool(
|
||||
system_prompt=system_prompt,
|
||||
verify_data=messages,
|
||||
llm_model_id=str(memory_config.llm_model_id)
|
||||
)
|
||||
verify_result = await verify_tool.verify()
|
||||
|
||||
# Parse LLM verification result with error handling
|
||||
@@ -118,7 +123,7 @@ async def Verify(
|
||||
"history": history,
|
||||
}
|
||||
|
||||
logger.info(f"验证==>>:{messages_deal}")
|
||||
logger.info(f"Verification result: {messages_deal}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
@@ -128,7 +133,7 @@ async def Verify(
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "数据验证",
|
||||
"title": "Data Verification",
|
||||
"result": messages_deal.get("split_result", "unknown"),
|
||||
"reason": messages_deal.get("reason", ""),
|
||||
"query": query,
|
||||
@@ -166,4 +171,4 @@ async def Verify(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('验证', duration)
|
||||
log_time('Verification', duration)
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import TypedDict, Annotated
|
||||
import os
|
||||
import logging
|
||||
|
||||
from jinja2 import Template
|
||||
from langchain_core.messages import AnyMessage
|
||||
from dotenv import load_dotenv
|
||||
from langgraph.graph import add_messages
|
||||
from openai import OpenAI
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_picture_config,
|
||||
get_voice_config,
|
||||
)
|
||||
|
||||
# Removed global variable imports - use dependency injection instead
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
from openai import OpenAI
|
||||
|
||||
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,6 +43,7 @@ class WriteState(TypedDict):
|
||||
user_id:str
|
||||
apply_id:str
|
||||
group_id:str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
|
||||
class ReadState(TypedDict):
|
||||
'''
|
||||
@@ -53,6 +53,7 @@ class ReadState(TypedDict):
|
||||
loop_count:Traverse times
|
||||
search_switch:type
|
||||
config_id: configuration id for filtering results
|
||||
errors: list of errors that occurred during workflow execution
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息
|
||||
name: str
|
||||
@@ -63,6 +64,7 @@ class ReadState(TypedDict):
|
||||
apply_id: str
|
||||
group_id: str
|
||||
config_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
|
||||
|
||||
class COUNTState:
|
||||
@@ -109,9 +111,17 @@ def deduplicate_entries(entries):
|
||||
|
||||
|
||||
|
||||
async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
|
||||
async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION, picture_model_name: str) -> str:
|
||||
"""
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
PROMPT_TICKET_EXTRACTION: Extraction prompt
|
||||
picture_model_name: Picture model name (required, no longer from global variables)
|
||||
"""
|
||||
try:
|
||||
model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME)
|
||||
model_config = get_picture_config(picture_model_name)
|
||||
except Exception as e:
|
||||
err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
|
||||
logger.error(err)
|
||||
@@ -147,9 +157,15 @@ async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
|
||||
picture_text = json.loads(picture_text)
|
||||
return (picture_text['statement'])
|
||||
|
||||
async def Voice_recognize():
|
||||
async def Voice_recognize(voice_model_name: str):
|
||||
"""
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
voice_model_name: Voice model name (required, no longer from global variables)
|
||||
"""
|
||||
try:
|
||||
model_config = get_voice_config(SELECTED_LLM_VOICE_NAME)
|
||||
model_config = get_voice_config(voice_model_name)
|
||||
except Exception as e:
|
||||
err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
|
||||
logger.error(err)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Any
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -119,11 +119,23 @@ async def Problem_Extension_messages_deal(context):
|
||||
extent_quest = []
|
||||
original = context.get('original', '')
|
||||
messages = context.get('context', '')
|
||||
messages = json.loads(messages)
|
||||
for message in messages:
|
||||
question = message.get('question', '')
|
||||
type = message.get('type', '')
|
||||
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
|
||||
|
||||
# Handle empty or non-string messages
|
||||
if not messages:
|
||||
return extent_quest, original
|
||||
|
||||
if isinstance(messages, str):
|
||||
try:
|
||||
messages = json.loads(messages)
|
||||
except json.JSONDecodeError:
|
||||
# If JSON parsing fails, return empty list
|
||||
return extent_quest, original
|
||||
|
||||
if isinstance(messages, list):
|
||||
for message in messages:
|
||||
question = message.get('question', '')
|
||||
type = message.get('type', '')
|
||||
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
|
||||
|
||||
return extent_quest, original
|
||||
|
||||
@@ -135,10 +147,19 @@ async def Retriev_messages_deal(context):
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
logger.info(f"Retriev_messages_deal input: type={type(context)}, value={str(context)[:500]}")
|
||||
|
||||
if isinstance(context, dict):
|
||||
logger.info(f"Retriev_messages_deal: context is dict with keys={list(context.keys())}")
|
||||
if 'context' in context or 'original' in context:
|
||||
return context.get('context', {}), context.get('original', '')
|
||||
return content, original_value
|
||||
content = context.get('context', {})
|
||||
original = context.get('original', '')
|
||||
logger.info(f"Retriev_messages_deal output: content_type={type(content)}, content={str(content)[:300]}, original='{original[:50] if original else ''}'")
|
||||
return content, original
|
||||
|
||||
# Return empty defaults if context is not a dict or doesn't have expected keys
|
||||
logger.warning(f"Retriev_messages_deal: context missing expected keys, returning empty defaults")
|
||||
return {}, ''
|
||||
|
||||
async def Verify_messages_deal(context):
|
||||
'''
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
# 角色
|
||||
你是验证专家
|
||||
你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析,是不是回答Query_Samll这个字段的问题
|
||||
你的目标是针对用户的输入Query_Small字段的提问和Answer_Small的回答分析,是不是回答Query_Small这个字段的问题
|
||||
|
||||
{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#}
|
||||
## 工作步骤
|
||||
1. 获取所有的Query_Samll字段和Answer_Samll字段
|
||||
2. 分析Answer_Samll的回复是不是和Query_Samll有关系
|
||||
3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态
|
||||
1. 获取所有的Query_Small字段和Answer_Small字段
|
||||
2. 分析Answer_Small的回复是不是和Query_Small有关系
|
||||
3. 判断Answer_Small和Query_Small之间分析出来的关系状态
|
||||
4. 如果是True保留,否则不要相对应的问题和回答
|
||||
5. 输出,需要严格按照模版
|
||||
输入:{{history}}
|
||||
历史消息:{"history":{{sentence}}}
|
||||
### 第一步 获取用户的输入
|
||||
获取用户的输入提取对应的Query_Samll和Answer_Samll
|
||||
获取用户的输入提取对应的Query_Small和Answer_Small
|
||||
### 第二步 分析验证
|
||||
需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容,如果有关系不是答非所问
|
||||
需要分析Query_Small和Answer_Small之间的关系可以参考history字段的内容,如果有关系不是答非所问
|
||||
## 核心验证标准
|
||||
在评估子问题拆分时,必须严格遵循以下标准,且验证过程中完全不依赖于子问题的相关信息(Answer_Samll):
|
||||
在评估子问题拆分时,必须严格遵循以下标准,且验证过程中完全不依赖于子问题的相关信息(Answer_Small):
|
||||
1. 合理性标准(必须全部满足):
|
||||
- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。
|
||||
- 最小化:每个不同的子问题数量应尽可能少,通常不超过原问题关键要素数量的2倍(建议2-4个),避免冗余和不必要拆分。
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
"""
|
||||
Type classification utility for distinguishing read/write operations.
|
||||
"""
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.config import settings
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -19,12 +18,14 @@ class DistinguishTypeResponse(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
async def status_typle(messages: str) -> dict:
|
||||
async def status_typle(messages: str, llm_model_id: str) -> dict:
|
||||
"""
|
||||
Classify message type as read or write operation.
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
messages: User message to classify
|
||||
llm_model_id: LLM model ID to use (required, no longer from global variables)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'type' field with classification result
|
||||
@@ -42,8 +43,9 @@ async def status_typle(messages: str) -> dict:
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(llm_model_id)
|
||||
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
from typing import TypedDict, Annotated, List, Any
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph, add_messages
|
||||
import asyncio
|
||||
import json
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
import os
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from langchain_core.messages import HumanMessage
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
|
||||
from typing import Annotated, Any, List, TypedDict
|
||||
|
||||
# Removed global variable imports - use dependency injection instead
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from langchain_core.messages import AnyMessage, HumanMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph, add_messages
|
||||
|
||||
load_dotenv(find_dotenv())
|
||||
|
||||
@@ -31,8 +32,17 @@ class State(TypedDict):
|
||||
|
||||
|
||||
class VerifyTool:
|
||||
def __init__(self, system_prompt: str="", verify_data: Any=None):
|
||||
def __init__(self, system_prompt: str="", verify_data: Any=None, llm_model_id: str=None):
|
||||
"""
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
system_prompt: System prompt for verification
|
||||
verify_data: Data to verify
|
||||
llm_model_id: LLM model ID (required, no longer from global variables)
|
||||
"""
|
||||
self.system_prompt = system_prompt
|
||||
self.llm_model_id = llm_model_id
|
||||
if isinstance(verify_data, str):
|
||||
self.verify_data = verify_data
|
||||
else:
|
||||
@@ -42,7 +52,11 @@ class VerifyTool:
|
||||
self.verify_data = str(verify_data)
|
||||
|
||||
async def model_1(self, state: State) -> State:
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
if not self.llm_model_id:
|
||||
raise ValueError("llm_model_id is required but not provided")
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(self.llm_model_id)
|
||||
response_content = await llm_client.chat(
|
||||
messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])]
|
||||
)
|
||||
|
||||
@@ -1,91 +1,93 @@
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
"""
|
||||
Write Tools for Memory Knowledge Extraction Pipeline
|
||||
|
||||
This module provides the main write function for executing the knowledge extraction
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation_all,
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# 导入配置模块(而不是直接导入变量)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None:
|
||||
|
||||
async def write(
|
||||
content: str,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
ref_id: str = "wyl20251027",
|
||||
) -> None:
|
||||
"""
|
||||
执行完整的知识提取流水线(使用新的 ExtractionOrchestrator)
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
|
||||
Only MemoryConfig is needed - LLM and embedding clients are constructed
|
||||
internally from the config.
|
||||
|
||||
Args:
|
||||
content: 对话内容
|
||||
user_id: 用户ID
|
||||
apply_id: 应用ID
|
||||
group_id: 组ID
|
||||
ref_id: 参考ID,默认为 "wyl20251027"
|
||||
config_id: 配置ID,用于标记数据处理配置
|
||||
content: Dialogue content to process
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
"""
|
||||
# 如果提供了config_id,重新加载配置
|
||||
if config_id:
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
logger.info(f"Reloading configuration for config_id: {config_id}")
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
logger.info(f"Configuration reloaded successfully for config_id: {config_id}")
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
config_id = str(memory_config.config_id)
|
||||
|
||||
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
||||
logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}")
|
||||
logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}")
|
||||
logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}")
|
||||
logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}")
|
||||
logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}")
|
||||
logger.info(f"Config ID: {config_id if config_id else 'None'}")
|
||||
logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}")
|
||||
logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}")
|
||||
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
|
||||
logger.info(f"Workspace: {memory_config.workspace_name}")
|
||||
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
||||
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||
logger.info(f"Group ID: {group_id}")
|
||||
|
||||
# Construct clients from memory_config using factory pattern with db session
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
embedder_client = factory.get_embedder_client_from_config(memory_config)
|
||||
logger.info("LLM and embedding clients constructed")
|
||||
|
||||
# Initialize timing log
|
||||
log_file = "logs/time.log"
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")
|
||||
f.write(f"Config: {memory_config.config_name} (ID: {config_id})\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
|
||||
# 初始化客户端
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
# 获取 embedder 配置
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
# Initialize Neo4j connector
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
# Step 1: 加载和分块数据
|
||||
|
||||
# Step 1: Load and chunk data
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await get_chunked_dialogs(
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
chunker_strategy=chunker_strategy,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
@@ -94,21 +96,21 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
config_id=config_id,
|
||||
)
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# Step 2: 初始化并运行 ExtractionOrchestrator
|
||||
|
||||
# Step 2: Initialize and run ExtractionOrchestrator
|
||||
step_start = time.time()
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
pipeline_config = get_pipeline_config(memory_config)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
config=pipeline_config,
|
||||
embedding_id=embedding_model_id,
|
||||
)
|
||||
|
||||
# 运行完整的提取流水线
|
||||
# orchestrator.run returns a flat tuple of 7 values after deduplication
|
||||
|
||||
# Run the complete extraction pipeline
|
||||
(
|
||||
all_dialogue_nodes,
|
||||
all_chunk_nodes,
|
||||
@@ -118,14 +120,12 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
all_dedup_details,
|
||||
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
# Step 8: Save all data to Neo4j database using graph models
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
# 运行索引创建
|
||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||
try:
|
||||
await create_fulltext_indexes()
|
||||
@@ -152,18 +152,16 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
# Step 9: Generate Memory summaries and save to local vector DB and Neo4j
|
||||
# Step 4: Generate Memory summaries and save to Neo4j
|
||||
step_start = time.time()
|
||||
try:
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client
|
||||
)
|
||||
|
||||
# Save memory summaries to Neo4j as nodes
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
# Link summaries to statements via chunks for summary→entity queries
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
try:
|
||||
@@ -173,24 +171,15 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
finally:
|
||||
log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
|
||||
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
# Log total pipeline time
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||
|
||||
# Add completion marker to log
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Timing details saved to: {log_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?"
|
||||
asyncio.run(write(content, ref_id="wyl20251027"))
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
from neo4j import GraphDatabase
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
from neo4j import GraphDatabase
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ------------------- 自包含路径解析 -------------------
|
||||
@@ -31,21 +32,54 @@ except NameError:
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
# 现在路径已经配置好,我们可以使用绝对导入
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
#TODO: Fix this
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
# 定义用于LLM结构化输出的Pydantic模型
|
||||
class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], llm_client) -> List[str]:
|
||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
"""
|
||||
try:
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
# 3. 构建Prompt
|
||||
tag_list_str = ", ".join(tags)
|
||||
@@ -140,8 +174,8 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
"""
|
||||
# 默认从 runtime.json selections.group_id 读取
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
# 默认从环境变量读取
|
||||
group_id = group_id or DEFAULT_GROUP_ID
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = get_raw_tags_from_db(group_id, limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
@@ -150,9 +184,7 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, llm_client)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
@@ -165,8 +197,8 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
|
||||
if __name__ == "__main__":
|
||||
print("开始获取热门记忆标签...")
|
||||
try:
|
||||
# 直接使用 runtime.json 中的 group_id
|
||||
group_id_to_query = SELECTED_GROUP_ID
|
||||
# 直接使用环境变量中的 group_id
|
||||
group_id_to_query = DEFAULT_GROUP_ID
|
||||
# 使用 asyncio.run 来执行异步主函数
|
||||
top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query))
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ This script can be executed directly to generate a memory insight report for a t
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
|
||||
@@ -17,12 +17,18 @@ src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
|
||||
#TODO: Fix this
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
# 定义用于LLM结构化输出的Pydantic模型
|
||||
class TagClassification(BaseModel):
|
||||
@@ -55,8 +61,33 @@ class MemoryInsight:
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
"""关闭数据库连接。"""
|
||||
@@ -294,8 +325,8 @@ async def main():
|
||||
"""
|
||||
Initializes and runs the memory insight analysis for a test user.
|
||||
"""
|
||||
# 默认从 runtime.json selections.group_id 读取
|
||||
test_user_id = SELECTED_GROUP_ID
|
||||
# 默认从环境变量读取
|
||||
test_user_id = DEFAULT_GROUP_ID
|
||||
print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")
|
||||
|
||||
insight = None
|
||||
|
||||
@@ -6,10 +6,10 @@ Usage:
|
||||
python -m analytics.user_summary --user_id <group_id>
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
@@ -24,10 +24,17 @@ try:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -42,8 +49,33 @@ class UserSummary:
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.connector = Neo4jConnector()
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
await self.connector.close()
|
||||
@@ -107,8 +139,8 @@ class UserSummary:
|
||||
|
||||
|
||||
async def generate_user_summary(user_id: str | None = None) -> str:
|
||||
# 默认从 runtime.json selections.group_id 读取
|
||||
effective_group_id = user_id or SELECTED_GROUP_ID
|
||||
# 默认从环境变量读取
|
||||
effective_group_id = user_id or DEFAULT_GROUP_ID
|
||||
svc = UserSummary(effective_group_id)
|
||||
try:
|
||||
return await svc.generate()
|
||||
@@ -139,7 +171,7 @@ if __name__ == "__main__":
|
||||
with open(dashboard_path, "r", encoding="utf-8") as rf:
|
||||
existing = json.load(rf)
|
||||
existing["user_summary"] = {
|
||||
"group_id": SELECTED_GROUP_ID,
|
||||
"group_id": DEFAULT_GROUP_ID,
|
||||
"summary": summary
|
||||
}
|
||||
with open(dashboard_path, "w", encoding="utf-8") as wf:
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
{
|
||||
"llm_list": [
|
||||
{
|
||||
"llm_name": "qwen2.5-14b-instruct-awq",
|
||||
"api_base": "http://175.27.131.196:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen2.5-14b-instruct-awq",
|
||||
"api_base": "http://175.27.131.196:9090/v1",
|
||||
"api_key": "OPENAI_API_AGENT_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen2.5-14b",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen2.5-14b-instruct-awq",
|
||||
"api_base": "http://175.27.131.196:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen3-14b",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/deepseek-r1-0528-qwen3-8b",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen3-235b-a22b-instruct-2507",
|
||||
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"api_key": "DASHSCOPE_API_KEY"
|
||||
}
|
||||
,
|
||||
{
|
||||
"llm_name": "openai/qwen-plus",
|
||||
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"api_key": "DASHSCOPE_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||
},
|
||||
{
|
||||
"llm_name": "bedrock/global.anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
}
|
||||
],
|
||||
"embedding_list": [
|
||||
{
|
||||
"embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
"api_base": "http://119.45.239.97:11434/v1",
|
||||
"dimension": 768
|
||||
},
|
||||
{
|
||||
"embedding_name": "openai/bge-m3",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"dimension": 1024
|
||||
}
|
||||
],
|
||||
"neo4j": {
|
||||
"uri": "bolt://1.94.111.67:7687",
|
||||
"username": "neo4j"
|
||||
},
|
||||
"chunker_list": [
|
||||
{
|
||||
"chunker_strategy": "TokenChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"chunk_overlap": 56,
|
||||
"tokenizer_or_token_counter": "character"
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"min_characters_per_chunk": 50
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "SemanticChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 1024,
|
||||
"threshold": 0.8,
|
||||
"min_sentences": 2,
|
||||
"skip_window": 1,
|
||||
"min_characters_per_chunk": 100
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "LateChunker",
|
||||
"embedding_model": "all-MiniLM-L6-v2",
|
||||
"chunk_size": 2048,
|
||||
"min_characters_per_chunk": 24
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "NeuralChunker",
|
||||
"embedding_model": "mirth/chonky_modernbert_base_1",
|
||||
"min_characters_per_chunk": 24
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "LLMChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 1000,
|
||||
"min_characters_per_chunk": 100
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "HybridChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"threshold": 0.8,
|
||||
"min_characters_per_chunk": 100
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "SentenceChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 2048,
|
||||
"chunk_overlap": 128,
|
||||
"min_sentences_per_chunk": 1,
|
||||
"min_characters_per_sentence": 12,
|
||||
"delim": [".", "!", "?", "\n"],
|
||||
"include_delim": "prev",
|
||||
"tokenizer_or_token_counter": "character"
|
||||
}
|
||||
],
|
||||
"langfuse": {
|
||||
"enabled": true
|
||||
},
|
||||
"agenta": {
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -1,5 +0,0 @@
|
||||
{
|
||||
"selections": {
|
||||
"config_id": ""
|
||||
}
|
||||
}
|
||||
@@ -1,22 +1,34 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_CHUNKER_STRATEGY, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_CHUNKER_STRATEGY,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
|
||||
# Import from database module
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# Cypher queries for evaluation
|
||||
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
|
||||
@@ -52,7 +64,9 @@ async def ingest_contexts_via_full_pipeline(
|
||||
llm_available = True
|
||||
try:
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
|
||||
llm_available = False
|
||||
@@ -133,12 +147,13 @@ async def ingest_contexts_via_full_pipeline(
|
||||
return False
|
||||
|
||||
# 初始化 embedder 客户端
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
embedder_config_dict = get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
except Exception as e:
|
||||
@@ -236,15 +251,15 @@ async def ingest_contexts_via_full_pipeline(
|
||||
print("[Ingestion] Generating memory summaries...")
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
Memory_summary_generation,
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
|
||||
summaries = await Memory_summary_generation(
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs=dialog_data_list,
|
||||
llm_client=llm_client,
|
||||
embedding_id=embedding_name or SELECTED_EMBEDDING_ID
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
|
||||
except Exception as e:
|
||||
|
||||
@@ -15,7 +15,7 @@ import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -23,37 +23,38 @@ except ImportError:
|
||||
def load_dotenv():
|
||||
pass
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
SELECTED_EMBEDDING_ID
|
||||
)
|
||||
from app.core.memory.utils.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
f1_score,
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
avg_context_tokens
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_metrics import (
|
||||
get_category_name,
|
||||
locomo_f1_score,
|
||||
locomo_multi_f1,
|
||||
get_category_name
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_utils import (
|
||||
load_locomo_data,
|
||||
extract_conversations,
|
||||
ingest_conversations_if_needed,
|
||||
load_locomo_data,
|
||||
resolve_temporal_references,
|
||||
select_and_format_information,
|
||||
retrieve_relevant_information,
|
||||
ingest_conversations_if_needed
|
||||
select_and_format_information,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
async def run_locomo_benchmark(
|
||||
@@ -160,10 +161,16 @@ async def run_locomo_benchmark(
|
||||
# Step 3: Initialize clients
|
||||
print("🔧 Initializing clients...")
|
||||
connector = Neo4jConnector()
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Initialize LLM client with database context
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Initialize embedder
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
# file name: check_neo4j_connection_fixed.py
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 1
|
||||
# 添加项目根目录到路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -34,7 +36,7 @@ def _loc_normalize(text: str) -> str:
|
||||
|
||||
# 尝试从 metrics.py 导入基础指标
|
||||
try:
|
||||
from common.metrics import f1_score, bleu1, jaccard
|
||||
from common.metrics import bleu1, f1_score, jaccard
|
||||
print("✅ 从 metrics.py 导入基础指标成功")
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 metrics.py 导入失败: {e}")
|
||||
@@ -111,10 +113,14 @@ try:
|
||||
|
||||
# 尝试从不同位置导入
|
||||
try:
|
||||
from locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
|
||||
from locomo.qwen_search_eval import (
|
||||
_resolve_relative_times,
|
||||
loc_f1_score,
|
||||
loc_multi_f1,
|
||||
)
|
||||
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
except ImportError:
|
||||
from qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
|
||||
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
|
||||
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
|
||||
except ImportError as e:
|
||||
@@ -429,13 +435,17 @@ async def run_enhanced_evaluation():
|
||||
return None
|
||||
|
||||
# 修正导入路径:使用 app.core.memory.src 前缀
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# 加载数据
|
||||
# 获取项目根目录
|
||||
@@ -458,10 +468,14 @@ async def run_enhanced_evaluation():
|
||||
# 初始化增强监控器
|
||||
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
|
||||
|
||||
llm = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化embedder
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -2,10 +2,11 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
import statistics
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
@@ -13,16 +14,31 @@ except Exception:
|
||||
return None
|
||||
|
||||
import re
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
|
||||
@@ -327,9 +343,13 @@ async def run_locomo_eval(
|
||||
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
|
||||
|
||||
# 使用异步LLM客户端
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
# 初始化embedder用于直接调用
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -2,11 +2,11 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -16,6 +16,7 @@ except Exception:
|
||||
|
||||
# 确保可以找到 src 及项目根路径
|
||||
import sys
|
||||
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR)))
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
@@ -25,19 +26,33 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
|
||||
# 与现有评估脚本保持一致的导入方式
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
try:
|
||||
# 优先从 extraction_utils1 导入
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline # type: ignore
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline, # type: ignore
|
||||
)
|
||||
except Exception:
|
||||
ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import exact_match
|
||||
except Exception:
|
||||
@@ -686,9 +701,13 @@ async def run_longmemeval_test(
|
||||
)
|
||||
|
||||
# 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
connector = Neo4jConnector()
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
@@ -748,10 +767,10 @@ async def run_longmemeval_test(
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
# for sm in summaries:
|
||||
# summary_text = str(sm.get("summary", "")).strip()
|
||||
# if summary_text:
|
||||
# contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要(最多3个)
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
|
||||
@@ -2,11 +2,11 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,15 +15,26 @@ except Exception:
|
||||
return None
|
||||
|
||||
# 与现有评估脚本保持一致的导入方式
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import exact_match
|
||||
except Exception:
|
||||
@@ -647,9 +658,13 @@ async def run_longmemeval_test(
|
||||
items = qa_list[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化组件 - 使用异步LLM客户端
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
connector = Neo4jConnector()
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -4,19 +4,35 @@ import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
@@ -119,7 +135,7 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
|
||||
return merged
|
||||
|
||||
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid") -> Dict[str, Any]:
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
# Load data
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
@@ -134,7 +150,9 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
await ingest_contexts_via_full_pipeline(contexts, group_id)
|
||||
|
||||
# LLM client (使用异步调用)
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Evaluate each item
|
||||
connector = Neo4jConnector()
|
||||
@@ -159,6 +177,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
limit=search_limit,
|
||||
include=["dialogues", "statements", "entities"],
|
||||
output_path=None,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
except Exception:
|
||||
results = None
|
||||
@@ -242,7 +261,11 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
|
||||
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
|
||||
correct_flags.append(exact_match(pred, reference))
|
||||
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
)
|
||||
f1s.append(f1_score(str(pred), str(reference)))
|
||||
b1s.append(bleu1(str(pred), str(reference)))
|
||||
jss.append(jaccard(str(pred), str(reference)))
|
||||
|
||||
@@ -2,10 +2,10 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,6 +15,7 @@ except Exception:
|
||||
|
||||
# 路径与模块导入保持与现有评估脚本一致
|
||||
import sys
|
||||
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
@@ -23,17 +24,27 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
sys.path.insert(0, _p)
|
||||
|
||||
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
|
||||
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
|
||||
except Exception:
|
||||
# 兜底:简单实现(必要时)
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
@@ -226,13 +237,17 @@ async def run_memsciqa_test(
|
||||
items = all_items[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化 LLM(纯测试:不进行摄入)
|
||||
llm = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test)
|
||||
connector = Neo4jConnector()
|
||||
embedder = None
|
||||
if search_type in ("embedding", "hybrid"):
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -5,18 +5,17 @@ OpenAI LLM 客户端实现
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.llm import RedBearLLM
|
||||
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
|
||||
from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,7 +42,7 @@ class OpenAIClient(LLMClient):
|
||||
|
||||
# 初始化 Langfuse 回调处理器(如果启用)
|
||||
self.langfuse_handler = None
|
||||
if LANGFUSE_ENABLED:
|
||||
if settings.LANGFUSE_ENABLED:
|
||||
try:
|
||||
from langfuse.langchain import CallbackHandler
|
||||
self.langfuse_handler = CallbackHandler()
|
||||
|
||||
@@ -1,403 +0,0 @@
|
||||
"""
|
||||
MemSci 记忆系统主入口 - 重构版本
|
||||
|
||||
该模块是重构后的记忆系统主入口,使用新的模块化架构。
|
||||
旧版本入口(app/core/memory/src/main.py)已删除。
|
||||
|
||||
主要功能:
|
||||
1. 协调整个知识提取流水线
|
||||
2. 支持试运行模式和正常运行模式
|
||||
3. 使用重构后的 storage_services 模块
|
||||
4. 提供统一的配置管理和日志记录
|
||||
|
||||
作者:Lance77
|
||||
日期:2025-11-22
|
||||
"""
|
||||
|
||||
# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误
|
||||
import os
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "false"
|
||||
os.environ["LANGCHAIN_TRACING"] = "false"
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, Callable, Awaitable
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 导入重构后的模块
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
|
||||
# 导入数据加载函数
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
get_chunked_dialogs_with_preprocessing,
|
||||
get_chunked_dialogs_from_preprocessed,
|
||||
)
|
||||
# 导入配置模块(而不是直接导入变量)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.logging_config import get_memory_logger, log_time
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def main(
|
||||
dialogue_text: Optional[str] = None,
|
||||
is_pilot_run: bool = False,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
|
||||
):
|
||||
"""
|
||||
记忆系统主流程 - 重构版本
|
||||
|
||||
该函数是重构后的主入口,使用新的模块化架构。
|
||||
|
||||
Args:
|
||||
dialogue_text: 输入的对话文本(可选,用于试运行模式)
|
||||
is_pilot_run: 是否为试运行模式
|
||||
- True: 试运行模式,不保存到 Neo4j
|
||||
- False: 正常运行模式,保存到 Neo4j
|
||||
progress_callback: 可选的进度回调函数
|
||||
- 类型: Callable[[str, str, Optional[dict]], Awaitable[None]]
|
||||
- 参数1 (stage): 当前处理阶段标识符
|
||||
- 参数2 (message): 人类可读的进度消息
|
||||
- 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果
|
||||
- 在管线关键点调用以报告进度和结果数据
|
||||
|
||||
工作流程:
|
||||
1. 初始化客户端和配置
|
||||
2. 加载或准备数据
|
||||
3. 执行知识提取流水线
|
||||
4. 保存结果(正常模式)或输出结果(试运行模式)
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("MemSci 知识提取流水线 - 重构版本")
|
||||
print("=" * 60)
|
||||
print(f"运行模式: {'试运行(不保存到Neo4j)' if is_pilot_run else '正常运行(保存到Neo4j)'}")
|
||||
print("Using chunker strategy:", config_defs.SELECTED_CHUNKER_STRATEGY)
|
||||
print("Using group ID:", config_defs.SELECTED_GROUP_ID)
|
||||
print("Using model ID:", config_defs.SELECTED_LLM_ID)
|
||||
print("Using embedding model ID:", config_defs.SELECTED_EMBEDDING_ID)
|
||||
print("LANGFUSE_ENABLED:", config_defs.LANGFUSE_ENABLED)
|
||||
print("AGENTA_ENABLED:", config_defs.AGENTA_ENABLED)
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化日志
|
||||
log_file = "logs/time.log"
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
|
||||
try:
|
||||
# 步骤 1: 初始化客户端
|
||||
logger.info("Initializing clients...")
|
||||
step_start = time.time()
|
||||
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
# 获取 embedder 配置并转换为 RedBearModelConfig 对象
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
log_time("Client Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 2: 加载或准备数据
|
||||
logger.info("Loading data...")
|
||||
logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}")
|
||||
logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}")
|
||||
logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}")
|
||||
step_start = time.time()
|
||||
|
||||
if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip():
|
||||
# 试运行模式:处理前端传入的对话文本
|
||||
logger.info("[MAIN] ✓ Using frontend dialogue text (pilot run mode)")
|
||||
import re
|
||||
|
||||
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
|
||||
pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)"
|
||||
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
|
||||
messages = [
|
||||
ConversationMessage(role=r, msg=c.strip())
|
||||
for r, c in matches if c.strip()
|
||||
]
|
||||
|
||||
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
|
||||
if not messages:
|
||||
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
|
||||
|
||||
# 创建对话上下文和对话数据
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context,
|
||||
ref_id="pilot_dialog_1",
|
||||
group_id=config_defs.SELECTED_GROUP_ID,
|
||||
user_id=config_defs.SELECTED_USER_ID,
|
||||
apply_id=config_defs.SELECTED_APPLY_ID,
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"}
|
||||
)
|
||||
|
||||
# 进度回调:开始预处理文本
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
# 对前端传入的对话进行分块处理
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dialog in chunked_dialogs:
|
||||
for i, chunk in enumerate(dialog.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dialog.id,
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
# 进度回调:预处理文本完成
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
else:
|
||||
# 正常运行模式:从 testdata.json 文件加载
|
||||
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
|
||||
logger.info("Loading data from testdata.json...")
|
||||
test_data_path = os.path.join(
|
||||
os.path.dirname(__file__), "data", "testdata.json"
|
||||
)
|
||||
|
||||
if not os.path.exists(test_data_path):
|
||||
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
|
||||
|
||||
# 进度回调:开始预处理文本
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
group_id=config_defs.SELECTED_GROUP_ID,
|
||||
user_id=config_defs.SELECTED_USER_ID,
|
||||
apply_id=config_defs.SELECTED_APPLY_ID,
|
||||
indices=config_defs.SELECTED_TEST_DATA_INDICES,
|
||||
input_data_path=test_data_path,
|
||||
llm_client=llm_client,
|
||||
skip_cleaning=True,
|
||||
)
|
||||
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dialog in chunked_dialogs:
|
||||
for i, chunk in enumerate(dialog.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dialog.id,
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
# 进度回调:预处理文本完成
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 3: 初始化流水线编排器
|
||||
logger.info("Initializing extraction orchestrator...")
|
||||
step_start = time.time()
|
||||
|
||||
# 从 runtime.json 加载配置(已经过数据库覆写)
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}")
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback, # 传递进度回调
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 4: 执行知识提取流水线
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
|
||||
# 进度回调:正在知识抽取
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=is_pilot_run, # 传递试运行模式标志
|
||||
)
|
||||
|
||||
# 解包 extraction_result tuple
|
||||
# extraction_result 是一个包含 7 个元素的 tuple:
|
||||
# (dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes,
|
||||
# statement_chunk_edges, statement_entity_edges, entity_edges)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
# 进度回调:生成结果
|
||||
if progress_callback:
|
||||
await progress_callback("generating_results", "正在生成结果...")
|
||||
|
||||
|
||||
# 步骤 5: 保存结果或输出结果
|
||||
if is_pilot_run:
|
||||
logger.info("Pilot run mode: Skipping Neo4j save")
|
||||
print("\n试运行模式:跳过 Neo4j 保存,流水线处理完成。")
|
||||
print("提取结果已生成,可在相关输出中查看。")
|
||||
else:
|
||||
logger.info("Normal mode: Saving to Neo4j...")
|
||||
step_start = time.time()
|
||||
|
||||
# 创建索引和约束
|
||||
try:
|
||||
from app.repositories.neo4j.create_indexes import (
|
||||
create_fulltext_indexes,
|
||||
create_unique_constraints,
|
||||
)
|
||||
await create_fulltext_indexes()
|
||||
await create_unique_constraints()
|
||||
logger.info("Successfully created indexes and constraints")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes/constraints: {e}")
|
||||
|
||||
# 保存数据到 Neo4j
|
||||
try:
|
||||
from app.repositories.neo4j.graph_saver import (
|
||||
save_dialog_and_statements_to_neo4j,
|
||||
)
|
||||
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=dialogue_nodes,
|
||||
chunk_nodes=chunk_nodes,
|
||||
statement_nodes=statement_nodes,
|
||||
entity_nodes=entity_nodes,
|
||||
statement_chunk_edges=statement_chunk_edges,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
entity_edges=entity_edges,
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
print("\n✓ 成功保存所有数据到 Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
print("\n⚠ 部分数据保存到 Neo4j 失败")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to Neo4j: {e}", exc_info=True)
|
||||
print(f"\n✗ 保存到 Neo4j 失败: {e}")
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 6: 生成记忆摘要(可选)
|
||||
try:
|
||||
logger.info("Generating memory summaries...")
|
||||
step_start = time.time()
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
Memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import (
|
||||
add_memory_summary_statement_edges,
|
||||
)
|
||||
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
|
||||
)
|
||||
|
||||
if not is_pilot_run:
|
||||
# 保存记忆摘要到 Neo4j
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
await ms_connector.close()
|
||||
|
||||
log_time("Memory Summary Generation", time.time() - step_start, log_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline execution failed: {e}", exc_info=True)
|
||||
print(f"\n✗ 流水线执行失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
# 清理资源
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 记录总时间
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||
|
||||
# 添加完成标记
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Timing details saved to: {log_file}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ 流水线执行完成")
|
||||
print(f"✓ 总耗时: {total_time:.2f} 秒")
|
||||
print(f"✓ 详细日志: {log_file}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -20,14 +20,14 @@ Classes:
|
||||
MemorySummaryNode: Node representing a memory summary
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
import re
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.core.memory.utils.alias_utils import validate_aliases
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
def parse_historical_datetime(v):
|
||||
@@ -361,7 +361,7 @@ class ExtractedEntityNode(Node):
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
fact_summary: str = Field(..., description="Summary of the fact about this entity")
|
||||
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
|
||||
@@ -10,9 +10,10 @@ Classes:
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
"""Represents an extracted entity from dialogue.
|
||||
|
||||
@@ -1,31 +1,43 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dotenv import load_dotenv
|
||||
from datetime import datetime
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_graph_by_embedding, search_graph,
|
||||
search_graph_by_temporal, search_graph_by_keyword_temporal,
|
||||
search_graph_by_chunk_id
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.config_models import TemporalSearchParams
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config, get_pipeline_config
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import (
|
||||
ForgettingEngine,
|
||||
)
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_pipeline_config,
|
||||
)
|
||||
from app.core.memory.utils.data.text_utils import extract_plain_query
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_graph,
|
||||
search_graph_by_chunk_id,
|
||||
search_graph_by_embedding,
|
||||
search_graph_by_keyword_temporal,
|
||||
search_graph_by_temporal,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
@@ -131,7 +143,7 @@ def rerank_hybrid_results(
|
||||
|
||||
# Add keyword results with BM25 scores
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
@@ -139,7 +151,7 @@ def rerank_hybrid_results(
|
||||
|
||||
# Add or update with embedding results
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id:
|
||||
if item_id in combined_items:
|
||||
# Update existing item with embedding score
|
||||
@@ -220,7 +232,7 @@ def rerank_with_forgetting_curve(
|
||||
(keyword_items, False), (embedding_items, True)
|
||||
):
|
||||
for item in src_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if not item_id:
|
||||
continue
|
||||
existing = combined_items.get(item_id)
|
||||
@@ -266,26 +278,25 @@ def rerank_with_forgetting_curve(
|
||||
return reranked
|
||||
|
||||
|
||||
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = "search_log.txt"):
|
||||
"""Log search query information to file"""
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
"""Log search query information using the logger.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
search_type: Type of search (keyword, embedding, hybrid)
|
||||
group_id: Group identifier for filtering
|
||||
limit: Maximum number of results
|
||||
include: List of result types to include
|
||||
log_file: Deprecated parameter, kept for backward compatibility
|
||||
"""
|
||||
# Ensure the query text is plain and clean before logging
|
||||
cleaned_query = extract_plain_query(query_text)
|
||||
log_entry = {
|
||||
"timestamp": timestamp,
|
||||
# "query": query_text,
|
||||
"query": cleaned_query,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id,
|
||||
"limit": limit,
|
||||
"include": include
|
||||
}
|
||||
|
||||
# Append to log file
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
|
||||
|
||||
logger.info(f"Search logged: {query_text} ({search_type})")
|
||||
|
||||
# Log using the standard logger
|
||||
logger.info(
|
||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||
f"group_id={group_id}, limit={limit}, include={include}"
|
||||
)
|
||||
|
||||
|
||||
def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
|
||||
@@ -315,229 +326,229 @@ def apply_reranker_placeholder(
|
||||
If config enables reranker, annotate items with a final_score equal to combined_score
|
||||
and keep ordering. This is a no-op reranker to be replaced later.
|
||||
"""
|
||||
try:
|
||||
rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load reranker config: {e}")
|
||||
rc = {}
|
||||
if not rc or not rc.get("enabled", False):
|
||||
return results
|
||||
# try:
|
||||
# rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
|
||||
# except Exception as e:
|
||||
# logger.debug(f"Failed to load reranker config: {e}")
|
||||
# rc = {}
|
||||
# if not rc or not rc.get("enabled", False):
|
||||
# return results
|
||||
|
||||
top_k = int(rc.get("top_k", 100))
|
||||
model_name = rc.get("model", "placeholder")
|
||||
# top_k = int(rc.get("top_k", 100))
|
||||
# model_name = rc.get("model", "placeholder")
|
||||
|
||||
for cat, items in results.items():
|
||||
head = items[:top_k]
|
||||
for it in head:
|
||||
base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
|
||||
it["final_score"] = base
|
||||
it["reranker_model"] = model_name
|
||||
# Keep overall order by final_score if present, otherwise combined/score
|
||||
results[cat] = sorted(
|
||||
items,
|
||||
key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
|
||||
reverse=True,
|
||||
)
|
||||
# for cat, items in results.items():
|
||||
# head = items[:top_k]
|
||||
# for it in head:
|
||||
# base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
|
||||
# it["final_score"] = base
|
||||
# it["reranker_model"] = model_name
|
||||
# # Keep overall order by final_score if present, otherwise combined/score
|
||||
# results[cat] = sorted(
|
||||
# items,
|
||||
# key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
|
||||
# reverse=True,
|
||||
# )
|
||||
return results
|
||||
|
||||
|
||||
async def apply_llm_reranker(
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
reranker_client: Optional[Any] = None,
|
||||
llm_weight: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Apply LLM-based reranking to search results.
|
||||
# async def apply_llm_reranker(
|
||||
# results: Dict[str, List[Dict[str, Any]]],
|
||||
# query_text: str,
|
||||
# reranker_client: Optional[Any] = None,
|
||||
# llm_weight: Optional[float] = None,
|
||||
# top_k: Optional[int] = None,
|
||||
# batch_size: Optional[int] = None,
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Apply LLM-based reranking to search results.
|
||||
|
||||
Args:
|
||||
results: Search results organized by category
|
||||
query_text: Original search query
|
||||
reranker_client: Optional pre-initialized reranker client
|
||||
llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
top_k: Maximum number of items to rerank per category
|
||||
batch_size: Number of items to process concurrently
|
||||
# Args:
|
||||
# results: Search results organized by category
|
||||
# query_text: Original search query
|
||||
# reranker_client: Optional pre-initialized reranker client
|
||||
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
# top_k: Maximum number of items to rerank per category
|
||||
# batch_size: Number of items to process concurrently
|
||||
|
||||
Returns:
|
||||
Reranked results with final_score and reranker_model fields
|
||||
"""
|
||||
# Load reranker configuration from runtime.json
|
||||
try:
|
||||
rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load reranker config: {e}")
|
||||
rc = {}
|
||||
# Returns:
|
||||
# Reranked results with final_score and reranker_model fields
|
||||
# """
|
||||
# # Load reranker configuration from runtime.json
|
||||
# # try:
|
||||
# # rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
|
||||
# # except Exception as e:
|
||||
# # logger.debug(f"Failed to load reranker config: {e}")
|
||||
# # rc = {}
|
||||
|
||||
# Check if reranking is enabled
|
||||
enabled = rc.get("enabled", False)
|
||||
if not enabled:
|
||||
logger.debug("LLM reranking is disabled in configuration")
|
||||
return results
|
||||
# # Check if reranking is enabled
|
||||
# enabled = rc.get("enabled", False)
|
||||
# if not enabled:
|
||||
# logger.debug("LLM reranking is disabled in configuration")
|
||||
# return results
|
||||
|
||||
# Load configuration parameters with defaults
|
||||
llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
# # Load configuration parameters with defaults
|
||||
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
# top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
|
||||
# Initialize reranker client if not provided
|
||||
if reranker_client is None:
|
||||
try:
|
||||
reranker_client = get_reranker_client()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
return results
|
||||
# # Initialize reranker client if not provided
|
||||
# if reranker_client is None:
|
||||
# try:
|
||||
# reranker_client = get_reranker_client()
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
# return results
|
||||
|
||||
# Get model name for metadata
|
||||
model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
# # Get model name for metadata
|
||||
# model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
|
||||
# Process each category
|
||||
reranked_results = {}
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
items = results.get(category, [])
|
||||
if not items:
|
||||
reranked_results[category] = []
|
||||
continue
|
||||
# # Process each category
|
||||
# reranked_results = {}
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
# items = results.get(category, [])
|
||||
# if not items:
|
||||
# reranked_results[category] = []
|
||||
# continue
|
||||
|
||||
# Select top K items by combined_score for reranking
|
||||
sorted_items = sorted(
|
||||
items,
|
||||
key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
reverse=True
|
||||
)
|
||||
# # Select top K items by combined_score for reranking
|
||||
# sorted_items = sorted(
|
||||
# items,
|
||||
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
# reverse=True
|
||||
# )
|
||||
|
||||
top_items = sorted_items[:top_k]
|
||||
remaining_items = sorted_items[top_k:]
|
||||
# top_items = sorted_items[:top_k]
|
||||
# remaining_items = sorted_items[top_k:]
|
||||
|
||||
# Extract text content from each item
|
||||
def extract_text(item: Dict[str, Any]) -> str:
|
||||
"""Extract text content from a result item."""
|
||||
# Try different text fields based on category
|
||||
text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
return str(text).strip()
|
||||
# # Extract text content from each item
|
||||
# def extract_text(item: Dict[str, Any]) -> str:
|
||||
# """Extract text content from a result item."""
|
||||
# # Try different text fields based on category
|
||||
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
# return str(text).strip()
|
||||
|
||||
# Batch items for concurrent processing
|
||||
batches = []
|
||||
for i in range(0, len(top_items), batch_size):
|
||||
batch = top_items[i:i + batch_size]
|
||||
batches.append(batch)
|
||||
# # Batch items for concurrent processing
|
||||
# batches = []
|
||||
# for i in range(0, len(top_items), batch_size):
|
||||
# batch = top_items[i:i + batch_size]
|
||||
# batches.append(batch)
|
||||
|
||||
# Process batches concurrently
|
||||
async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Process a batch of items with LLM relevance scoring."""
|
||||
scored_batch = []
|
||||
# # Process batches concurrently
|
||||
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
# """Process a batch of items with LLM relevance scoring."""
|
||||
# scored_batch = []
|
||||
|
||||
for item in batch:
|
||||
item_text = extract_text(item)
|
||||
# for item in batch:
|
||||
# item_text = extract_text(item)
|
||||
|
||||
# Skip items with no text
|
||||
if not item_text:
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["llm_relevance_score"] = 0.0
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
continue
|
||||
# # Skip items with no text
|
||||
# if not item_text:
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["llm_relevance_score"] = 0.0
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
# continue
|
||||
|
||||
# Create relevance scoring prompt
|
||||
prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
# # Create relevance scoring prompt
|
||||
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
|
||||
Query: {query_text}
|
||||
# Query: {query_text}
|
||||
|
||||
Result: {item_text}
|
||||
# Result: {item_text}
|
||||
|
||||
Respond with only a number between 0.0 and 1.0, where:
|
||||
- 0.0 means completely irrelevant
|
||||
- 1.0 means perfectly relevant
|
||||
# Respond with only a number between 0.0 and 1.0, where:
|
||||
# - 0.0 means completely irrelevant
|
||||
# - 1.0 means perfectly relevant
|
||||
|
||||
Relevance score:"""
|
||||
# Relevance score:"""
|
||||
|
||||
# Send request to LLM
|
||||
try:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
response = await reranker_client.chat(messages)
|
||||
# # Send request to LLM
|
||||
# try:
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
# response = await reranker_client.chat(messages)
|
||||
|
||||
# Parse LLM response to extract relevance score
|
||||
response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
# # Parse LLM response to extract relevance score
|
||||
# response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
|
||||
# Try to extract a float from the response
|
||||
try:
|
||||
# Remove any non-numeric characters except decimal point
|
||||
import re
|
||||
score_match = re.search(r'(\d+\.?\d*)', response_text)
|
||||
if score_match:
|
||||
llm_score = float(score_match.group(1))
|
||||
# Clamp to [0.0, 1.0]
|
||||
llm_score = max(0.0, min(1.0, llm_score))
|
||||
else:
|
||||
raise ValueError("No numeric score found in response")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
llm_score = None
|
||||
# # Try to extract a float from the response
|
||||
# try:
|
||||
# # Remove any non-numeric characters except decimal point
|
||||
# import re
|
||||
# score_match = re.search(r'(\d+\.?\d*)', response_text)
|
||||
# if score_match:
|
||||
# llm_score = float(score_match.group(1))
|
||||
# # Clamp to [0.0, 1.0]
|
||||
# llm_score = max(0.0, min(1.0, llm_score))
|
||||
# else:
|
||||
# raise ValueError("No numeric score found in response")
|
||||
# except (ValueError, AttributeError) as e:
|
||||
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
# llm_score = None
|
||||
|
||||
# Calculate final score
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# # Calculate final score
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
|
||||
if llm_score is not None:
|
||||
final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
item_copy["llm_relevance_score"] = llm_score
|
||||
else:
|
||||
# Use combined_score as fallback
|
||||
final_score = combined_score
|
||||
item_copy["llm_relevance_score"] = combined_score
|
||||
# if llm_score is not None:
|
||||
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
# item_copy["llm_relevance_score"] = llm_score
|
||||
# else:
|
||||
# # Use combined_score as fallback
|
||||
# final_score = combined_score
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
|
||||
item_copy["final_score"] = final_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["llm_relevance_score"] = combined_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
# item_copy["final_score"] = final_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
|
||||
return scored_batch
|
||||
# return scored_batch
|
||||
|
||||
# Process all batches concurrently
|
||||
try:
|
||||
batch_tasks = [process_batch(batch) for batch in batches]
|
||||
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
# # Process all batches concurrently
|
||||
# try:
|
||||
# batch_tasks = [process_batch(batch) for batch in batches]
|
||||
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
# Merge batch results
|
||||
scored_items = []
|
||||
for result in batch_results:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"Batch processing failed: {result}")
|
||||
continue
|
||||
scored_items.extend(result)
|
||||
# # Merge batch results
|
||||
# scored_items = []
|
||||
# for result in batch_results:
|
||||
# if isinstance(result, Exception):
|
||||
# logger.warning(f"Batch processing failed: {result}")
|
||||
# continue
|
||||
# scored_items.extend(result)
|
||||
|
||||
# Add remaining items (not in top K) with their combined_score as final_score
|
||||
for item in remaining_items:
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_items.append(item_copy)
|
||||
# # Add remaining items (not in top K) with their combined_score as final_score
|
||||
# for item in remaining_items:
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_items.append(item_copy)
|
||||
|
||||
# Sort all items by final_score in descending order
|
||||
scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
reranked_results[category] = scored_items
|
||||
# # Sort all items by final_score in descending order
|
||||
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
# reranked_results[category] = scored_items
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# Return original items with combined_score as final_score
|
||||
for item in items:
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item["final_score"] = combined_score
|
||||
item["reranker_model"] = model_name
|
||||
reranked_results[category] = items
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# # Return original items with combined_score as final_score
|
||||
# for item in items:
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item["final_score"] = combined_score
|
||||
# item["reranker_model"] = model_name
|
||||
# reranked_results[category] = items
|
||||
|
||||
return reranked_results
|
||||
# return reranked_results
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
@@ -547,6 +558,7 @@ async def run_hybrid_search(
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
@@ -554,10 +566,14 @@ async def run_hybrid_search(
|
||||
"""
|
||||
|
||||
Run search with specified type: 'keyword', 'embedding', or 'hybrid'
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing embedding_model_id and config_id
|
||||
"""
|
||||
# Start overall timing
|
||||
search_start_time = time.time()
|
||||
latency_metrics = {}
|
||||
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
|
||||
|
||||
# Clean and normalize the incoming query before use/logging
|
||||
query_text = extract_plain_query(query_text)
|
||||
@@ -610,7 +626,9 @@ async def run_hybrid_search(
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
@@ -672,7 +690,7 @@ async def run_hybrid_search(
|
||||
if use_forgetting_rerank:
|
||||
# Load forgetting parameters from pipeline config
|
||||
try:
|
||||
pc = get_pipeline_config()
|
||||
pc = get_pipeline_config(memory_config)
|
||||
forgetting_cfg = pc.forgetting_engine
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
|
||||
@@ -700,16 +718,16 @@ async def run_hybrid_search(
|
||||
|
||||
# Apply LLM reranking if enabled
|
||||
llm_rerank_applied = False
|
||||
if use_llm_rerank:
|
||||
try:
|
||||
reranked_results = await apply_llm_reranker(
|
||||
results=reranked_results,
|
||||
query_text=query_text,
|
||||
)
|
||||
llm_rerank_applied = True
|
||||
logger.info("LLM reranking applied successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
# if use_llm_rerank:
|
||||
# try:
|
||||
# reranked_results = await apply_llm_reranker(
|
||||
# results=reranked_results,
|
||||
# query_text=query_text,
|
||||
# )
|
||||
# llm_rerank_applied = True
|
||||
# logger.info("LLM reranking applied successfully")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
|
||||
results["reranked_results"] = reranked_results
|
||||
results["combined_summary"] = {
|
||||
@@ -759,18 +777,11 @@ async def run_hybrid_search(
|
||||
else:
|
||||
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
|
||||
|
||||
completion_log = {
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"query": query_text,
|
||||
"search_type": search_type,
|
||||
"status": "completed",
|
||||
"result_counts": result_counts,
|
||||
"output_file": output_path,
|
||||
"latency_metrics": latency_metrics
|
||||
}
|
||||
|
||||
with open("search_log.txt", "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(completion_log, ensure_ascii=False) + "\n")
|
||||
# Log completion using the standard logger
|
||||
logger.info(
|
||||
f"Search completed: query='{query_text}', type={search_type}, "
|
||||
f"result_counts={result_counts}, latency={latency_metrics}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -892,89 +903,95 @@ async def search_chunk_by_chunk_id(
|
||||
return {"chunks": chunks}
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the hybrid graph search CLI.
|
||||
# def main():
|
||||
# """Main entry point for the hybrid graph search CLI.
|
||||
|
||||
Parses command line arguments and executes search with specified parameters.
|
||||
Supports keyword, embedding, and hybrid search modes.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
|
||||
parser.add_argument(
|
||||
"--query", "-q", required=True, help="Free-text query to search"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search-type",
|
||||
"-t",
|
||||
choices=["keyword", "embedding", "hybrid"],
|
||||
default="hybrid",
|
||||
help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-name",
|
||||
"-m",
|
||||
default="openai/nomic-embed-text:v1.5",
|
||||
help="Embedding config name from config.json (default: openai/nomic-embed-text:v1.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-id",
|
||||
"-g",
|
||||
default=None,
|
||||
help="Optional group_id to filter results (default: None)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
"-k",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Max number of results per type (default: 5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include",
|
||||
"-i",
|
||||
nargs="+",
|
||||
default=["statements", "chunks", "entities", "summaries"],
|
||||
choices=["statements", "chunks", "entities", "summaries"],
|
||||
help="Which targets to search for embedding search (default: statements chunks entities summaries)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
default="search_results.json",
|
||||
help="Path to save the search results JSON (default: search_results.json)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rerank-alpha",
|
||||
"-a",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--forgetting-rerank",
|
||||
action="store_true",
|
||||
help="Apply forgetting curve during reranking for hybrid search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-rerank",
|
||||
action="store_true",
|
||||
help="Apply LLM-based reranking for hybrid search.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
# Parses command line arguments and executes search with specified parameters.
|
||||
# Supports keyword, embedding, and hybrid search modes.
|
||||
# """
|
||||
# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
|
||||
# parser.add_argument(
|
||||
# "--query", "-q", required=True, help="Free-text query to search"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--search-type",
|
||||
# "-t",
|
||||
# choices=["keyword", "embedding", "hybrid"],
|
||||
# default="hybrid",
|
||||
# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--config-id",
|
||||
# "-c",
|
||||
# type=int,
|
||||
# required=True,
|
||||
# help="Database configuration ID (required)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--group-id",
|
||||
# "-g",
|
||||
# default=None,
|
||||
# help="Optional group_id to filter results (default: None)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--limit",
|
||||
# "-k",
|
||||
# type=int,
|
||||
# default=5,
|
||||
# help="Max number of results per type (default: 5)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--include",
|
||||
# "-i",
|
||||
# nargs="+",
|
||||
# default=["statements", "chunks", "entities", "summaries"],
|
||||
# choices=["statements", "chunks", "entities", "summaries"],
|
||||
# help="Which targets to search for embedding search (default: statements chunks entities summaries)"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--output",
|
||||
# "-o",
|
||||
# default="search_results.json",
|
||||
# help="Path to save the search results JSON (default: search_results.json)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--rerank-alpha",
|
||||
# "-a",
|
||||
# type=float,
|
||||
# default=0.6,
|
||||
# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--forgetting-rerank",
|
||||
# action="store_true",
|
||||
# help="Apply forgetting curve during reranking for hybrid search.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--llm-rerank",
|
||||
# action="store_true",
|
||||
# help="Apply LLM-based reranking for hybrid search.",
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
|
||||
asyncio.run(
|
||||
run_hybrid_search(
|
||||
query_text=args.query,
|
||||
search_type=args.search_type,
|
||||
group_id=args.group_id,
|
||||
limit=args.limit,
|
||||
include=args.include,
|
||||
output_path=args.output,
|
||||
rerank_alpha=args.rerank_alpha,
|
||||
use_forgetting_rerank=args.forgetting_rerank,
|
||||
use_llm_rerank=args.llm_rerank,
|
||||
)
|
||||
)
|
||||
# # Load memory config from database
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
# memory_config = MemoryConfigService.load_memory_config(args.config_id)
|
||||
|
||||
# asyncio.run(
|
||||
# run_hybrid_search(
|
||||
# query_text=args.query,
|
||||
# search_type=args.search_type,
|
||||
# group_id=args.group_id,
|
||||
# limit=args.limit,
|
||||
# include=args.include,
|
||||
# output_path=args.output,
|
||||
# memory_config=memory_config,
|
||||
# rerank_alpha=args.rerank_alpha,
|
||||
# use_forgetting_rerank=args.forgetting_rerank,
|
||||
# use_llm_rerank=args.llm_rerank,
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
"""
|
||||
去重功能函数
|
||||
"""
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from typing import List, Dict, Tuple, Any
|
||||
from app.core.memory.models.graph_models import(
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode
|
||||
)
|
||||
import os
|
||||
from datetime import datetime
|
||||
import difflib # 提供字符串相似度计算工具
|
||||
import asyncio
|
||||
import difflib # 提供字符串相似度计算工具
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
)
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
|
||||
|
||||
# 模块级类型统一工具函数
|
||||
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
|
||||
"""统一实体类型:基于LLM建议或启发式规则选择最合适的类型。
|
||||
@@ -705,7 +708,8 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
id_redirect: Dict[str, str],
|
||||
config: DedupConfig | None = None,
|
||||
config: DedupConfig,
|
||||
llm_client = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
|
||||
"""
|
||||
基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。
|
||||
@@ -717,26 +721,13 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
"""
|
||||
llm_records: List[str] = []
|
||||
try:
|
||||
# 优先使用运行时配置;若未提供配置,使用模型默认值,不再回退到环境变量
|
||||
enable_switch = (
|
||||
bool(config.enable_llm_dedup_blockwise) if config is not None else DedupConfig().enable_llm_dedup_blockwise
|
||||
)
|
||||
if not enable_switch:
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
# 从配置读取 LLM 迭代参数;若无配置则使用 DedupConfig 的默认值
|
||||
_defaults = DedupConfig()
|
||||
block_size = (config.llm_block_size if config is not None else _defaults.llm_block_size)
|
||||
block_concurrency = (config.llm_block_concurrency if config is not None else _defaults.llm_block_concurrency)
|
||||
pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency)
|
||||
max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds)
|
||||
|
||||
# 动态导入 llm 客户端(修正导入路径)
|
||||
try:
|
||||
llm_utils_mod = importlib.import_module("app.core.memory.utils.llm.llm_utils")
|
||||
get_llm_client_fn = llm_utils_mod.get_llm_client
|
||||
except Exception as e:
|
||||
llm_records.append(f"[LLM错误] 无法导入 llm_utils 模块: {e}")
|
||||
if not bool(config.enable_llm_dedup_blockwise):
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
# 从配置读取 LLM 迭代参数
|
||||
block_size = config.llm_block_size
|
||||
block_concurrency = config.llm_block_concurrency
|
||||
pair_concurrency = config.llm_pair_concurrency
|
||||
max_rounds = config.llm_max_rounds
|
||||
|
||||
try:
|
||||
llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm")
|
||||
@@ -745,14 +736,9 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
llm_records.append(f"[LLM错误] 无法导入 entity_dedup_llm 模块: {e}")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
# 获取 LLM 客户端
|
||||
try:
|
||||
llm_client = get_llm_client_fn()
|
||||
if llm_client is None:
|
||||
llm_records.append("[LLM错误] LLM 客户端初始化失败:返回 None")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
except Exception as e:
|
||||
llm_records.append(f"[LLM错误] 获取 LLM 客户端失败: {e}")
|
||||
# 验证 LLM 客户端
|
||||
if llm_client is None:
|
||||
llm_records.append("[LLM错误] LLM 客户端未提供")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
llm_redirect, llm_records = await llm_fn(
|
||||
@@ -813,7 +799,8 @@ async def LLM_disamb_decision(
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
id_redirect: Dict[str, str],
|
||||
config: DedupConfig | None = None,
|
||||
config: DedupConfig,
|
||||
llm_client = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]:
|
||||
"""
|
||||
预消歧阶段:对“同名但类型不同”的实体对调用LLM进行消歧,
|
||||
@@ -824,22 +811,16 @@ async def LLM_disamb_decision(
|
||||
disamb_records: List[str] = []
|
||||
blocked_pairs: set[tuple[str, str]] = set()
|
||||
try:
|
||||
enable_switch = (
|
||||
config.enable_llm_disambiguation
|
||||
if config is not None
|
||||
else DedupConfig().enable_llm_disambiguation
|
||||
)
|
||||
if not bool(enable_switch):
|
||||
if not bool(config.enable_llm_disambiguation):
|
||||
return deduped_entities, id_redirect, blocked_pairs, disamb_records
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import (
|
||||
llm_disambiguate_pairs_iterative,
|
||||
)
|
||||
|
||||
# 获取 LLM 客户端并验证
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
# 验证 LLM 客户端
|
||||
if llm_client is None:
|
||||
disamb_records.append("[DISAMB错误] LLM 客户端初始化失败:返回 None")
|
||||
disamb_records.append("[DISAMB错误] LLM 客户端未提供")
|
||||
return deduped_entities, id_redirect, blocked_pairs, disamb_records
|
||||
|
||||
merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative(
|
||||
@@ -895,6 +876,7 @@ async def deduplicate_entities_and_edges(
|
||||
report_append: bool = False,
|
||||
report_stage_notes: List[str] | None = None,
|
||||
dedup_config: DedupConfig | None = None,
|
||||
llm_client = None,
|
||||
) -> Tuple[
|
||||
List[ExtractedEntityNode],
|
||||
List[StatementEntityEdge],
|
||||
@@ -911,7 +893,7 @@ async def deduplicate_entities_and_edges(
|
||||
|
||||
# 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并
|
||||
deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision(
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
|
||||
)
|
||||
|
||||
# 2) 模糊匹配(本地规则)
|
||||
@@ -936,7 +918,7 @@ async def deduplicate_entities_and_edges(
|
||||
|
||||
if should_trigger_llm:
|
||||
deduped_entities, id_redirect, llm_decision_records = await LLM_decision(
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
|
||||
)
|
||||
else:
|
||||
llm_decision_records = []
|
||||
|
||||
@@ -10,15 +10,27 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
)
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( # 导入报告写入以在跳过时追加说明
|
||||
_write_dedup_fusion_report,
|
||||
deduplicate_entities_and_edges,
|
||||
)
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
get_dedup_candidates_for_entities, # 导入ge函数,用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互
|
||||
from app.repositories.neo4j.graph_search import get_dedup_candidates_for_entities # 导入ge函数,用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
|
||||
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges, _write_dedup_fusion_report # 导入报告写入以在跳过时追加说明
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from app.repositories.neo4j.neo4j_connector import (
|
||||
Neo4jConnector, # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互
|
||||
)
|
||||
|
||||
|
||||
def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt,用于将任意类型的输入值解析为datetime对象(处理实体节点中的时间字段)
|
||||
@@ -72,6 +84,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
||||
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
||||
dedup_config: DedupConfig | None = None,
|
||||
llm_client = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||
"""
|
||||
第二层去重消歧:
|
||||
@@ -137,13 +150,14 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes)
|
||||
|
||||
# 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重)
|
||||
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges = await deduplicate_entities_and_edges(
|
||||
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges, _ = await deduplicate_entities_and_edges(
|
||||
union_entities,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
report_stage="第二层去重消歧",
|
||||
report_append=True,
|
||||
dedup_config=dedup_config,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import second_layer_dedup_and_merge_with_neo4j
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.models.graph_models import (
|
||||
DialogueNode,
|
||||
ChunkNode,
|
||||
StatementNode,
|
||||
DialogueNode,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
deduplicate_entities_and_edges,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
|
||||
second_layer_dedup_and_merge_with_neo4j,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def dedup_layers_and_merge_and_return(
|
||||
@@ -29,8 +33,9 @@ async def dedup_layers_and_merge_and_return(
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: Optional[ExtractionPipelineConfig] = None,
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client = None,
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
@@ -48,12 +53,9 @@ async def dedup_layers_and_merge_and_return(
|
||||
返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。
|
||||
"""
|
||||
|
||||
# 默认从 runtime.json 加载管线配置,避免回退到环境变量
|
||||
# pipeline_config is required - caller must provide it
|
||||
if pipeline_config is None:
|
||||
try:
|
||||
pipeline_config = get_pipeline_config()
|
||||
except Exception:
|
||||
pipeline_config = None
|
||||
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
|
||||
|
||||
# 先探测 group_id,决定报告写入策略
|
||||
group_id: Optional[str] = None
|
||||
@@ -70,6 +72,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
report_stage="第一层去重消歧",
|
||||
report_append=False,
|
||||
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
# 初始化第二层融合结果为第一层结果
|
||||
@@ -88,6 +91,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
statement_entity_edges=dedup_statement_entity_edges,
|
||||
entity_entity_edges=dedup_entity_entity_edges,
|
||||
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
|
||||
llm_client=llm_client,
|
||||
)
|
||||
else:
|
||||
print("Skip second-layer dedup: missing connector")
|
||||
|
||||
@@ -19,48 +19,49 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable, Awaitable
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.graph_models import (
|
||||
DialogueNode,
|
||||
ChunkNode,
|
||||
StatementNode,
|
||||
DialogueNode,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||
dedup_layers_and_merge_and_return,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation,
|
||||
embedding_generation_all,
|
||||
generate_entity_embeddings_from_triplets,
|
||||
)
|
||||
|
||||
# 导入各个提取模块
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
|
||||
StatementExtractor,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.triplet_extraction import (
|
||||
TripletExtractor,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.temporal_extraction import (
|
||||
TemporalExtractor,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation,
|
||||
generate_entity_embeddings_from_triplets,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||
dedup_layers_and_merge_and_return,
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.triplet_extraction import (
|
||||
TripletExtractor,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
|
||||
_write_extracted_result_summary,
|
||||
export_test_input_doc,
|
||||
)
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -94,6 +95,7 @@ class ExtractionOrchestrator:
|
||||
connector: Neo4jConnector,
|
||||
config: Optional[ExtractionPipelineConfig] = None,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||
embedding_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
初始化流水线编排器
|
||||
@@ -106,6 +108,7 @@ class ExtractionOrchestrator:
|
||||
progress_callback: 进度回调函数
|
||||
- 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None]
|
||||
- 在管线关键点调用以报告进度和结果数据
|
||||
embedding_id: 嵌入模型ID,如果为 None 则从全局配置获取(向后兼容)
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.embedder_client = embedder_client
|
||||
@@ -113,6 +116,7 @@ class ExtractionOrchestrator:
|
||||
self.config = config or ExtractionPipelineConfig()
|
||||
self.is_pilot_run = False # 默认非试运行模式
|
||||
self.progress_callback = progress_callback # 保存进度回调函数
|
||||
self.embedding_id = embedding_id # 保存嵌入模型ID
|
||||
|
||||
# 保存去重消歧的详细记录(内存中的数据结构)
|
||||
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
|
||||
@@ -398,7 +402,9 @@ class ExtractionOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}")
|
||||
completed_statements += 1
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)]
|
||||
@@ -412,7 +418,9 @@ class ExtractionOrchestrator:
|
||||
d_idx, stmt_id = statement_metadata[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"陈述句处理异常: {result}")
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
triplet_maps[d_idx][stmt_id] = TripletExtractionResponse(triplets=[], entities=[])
|
||||
else:
|
||||
triplet_maps[d_idx][stmt_id] = result
|
||||
@@ -525,8 +533,8 @@ class ExtractionOrchestrator:
|
||||
temporal_maps[d_idx][stmt_id] = result
|
||||
|
||||
# 为 ATEMPORAL 陈述句添加空的时间范围
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.core.memory.models.message_models import TemporalValidityRange
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
for d_idx, dialog in enumerate(dialog_data_list):
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
@@ -738,17 +746,14 @@ class ExtractionOrchestrator:
|
||||
logger.info("开始生成基础嵌入向量(陈述句、分块、对话)")
|
||||
|
||||
try:
|
||||
# 从 runtime.json 获取嵌入模型配置ID
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
embedding_id = config_defs.SELECTED_EMBEDDING_ID
|
||||
|
||||
if not embedding_id:
|
||||
logger.error("未在 runtime.json 中配置 embedding 模型 ID")
|
||||
raise ValueError("未配置嵌入模型ID")
|
||||
# embedding_id is required - no fallback to global variable
|
||||
if not self.embedding_id:
|
||||
logger.error("embedding_id is required but was not provided to ExtractionOrchestrator")
|
||||
raise ValueError("embedding_id is required but was not provided")
|
||||
|
||||
# 只生成陈述句、分块和对话的嵌入(不包括实体)
|
||||
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = await embedding_generation(
|
||||
dialog_data_list, embedding_id
|
||||
dialog_data_list, self.embedding_id
|
||||
)
|
||||
|
||||
# 统计生成结果
|
||||
@@ -792,17 +797,14 @@ class ExtractionOrchestrator:
|
||||
logger.info("开始生成实体嵌入向量")
|
||||
|
||||
try:
|
||||
# 从 runtime.json 获取嵌入模型配置ID
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
embedding_id = config_defs.SELECTED_EMBEDDING_ID
|
||||
|
||||
if not embedding_id:
|
||||
logger.error("未在 runtime.json 中配置 embedding 模型 ID")
|
||||
# embedding_id is required - no fallback to global variable
|
||||
if not self.embedding_id:
|
||||
logger.error("embedding_id is required but was not provided to ExtractionOrchestrator")
|
||||
return triplet_maps
|
||||
|
||||
# 生成实体嵌入
|
||||
updated_triplet_maps = await generate_entity_embeddings_from_triplets(
|
||||
triplet_maps, embedding_id
|
||||
triplet_maps, self.embedding_id
|
||||
)
|
||||
|
||||
logger.info("实体嵌入生成完成")
|
||||
@@ -1240,7 +1242,9 @@ class ExtractionOrchestrator:
|
||||
if self.is_pilot_run:
|
||||
logger.info("试运行模式:仅执行第一层去重,跳过第二层数据库去重")
|
||||
# 只执行第一层去重
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
deduplicate_entities_and_edges,
|
||||
)
|
||||
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
|
||||
entity_nodes,
|
||||
@@ -1249,6 +1253,7 @@ class ExtractionOrchestrator:
|
||||
report_stage="第一层去重消歧(试运行)",
|
||||
report_append=False,
|
||||
dedup_config=self.config.deduplication,
|
||||
llm_client=self.llm_client,
|
||||
)
|
||||
|
||||
# 保存去重消歧的详细记录到实例变量
|
||||
@@ -1280,6 +1285,7 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
self.config,
|
||||
self.connector,
|
||||
llm_client=self.llm_client,
|
||||
)
|
||||
|
||||
# 解包返回值
|
||||
@@ -1824,7 +1830,9 @@ async def get_chunked_dialogs(
|
||||
)
|
||||
|
||||
# 创建分块器并处理对话
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
@@ -1871,7 +1879,9 @@ def preprocess_data(
|
||||
经过清洗转换后的 DialogData 列表
|
||||
"""
|
||||
print("\n=== 数据预处理 ===")
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
|
||||
DataPreprocessor,
|
||||
)
|
||||
preprocessor = DataPreprocessor()
|
||||
try:
|
||||
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
|
||||
@@ -1902,7 +1912,9 @@ async def get_chunked_dialogs_from_preprocessed(
|
||||
raise ValueError("预处理数据为空,无法进行分块")
|
||||
|
||||
all_chunked_dialogs: List[DialogData] = []
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
|
||||
for dialog_data in data:
|
||||
chunker = DialogueChunker(chunker_strategy, llm_client=llm_client)
|
||||
@@ -1963,7 +1975,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
|
||||
# 步骤2: 语义剪枝
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
|
||||
SemanticPruner,
|
||||
)
|
||||
pruner = SemanticPruner(llm_client=llm_client)
|
||||
|
||||
# 记录单对话场景下剪枝前的消息数量
|
||||
@@ -1986,7 +2000,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
|
||||
# 保存剪枝后的数据
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
|
||||
DataPreprocessor,
|
||||
)
|
||||
pruned_output_path = settings.get_memory_output_path("pruned_data.json")
|
||||
dp = DataPreprocessor(output_file_path=pruned_output_path)
|
||||
dp.save_data(preprocessed_data, output_path=pruned_output_path)
|
||||
|
||||
@@ -5,11 +5,13 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
class EmbeddingGenerator:
|
||||
@@ -21,7 +23,9 @@ class EmbeddingGenerator:
|
||||
Args:
|
||||
embedding_id: 嵌入模型 ID
|
||||
"""
|
||||
embedder_config = get_embedder_config(embedding_id)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config = config_service.get_embedder_config(embedding_id)
|
||||
self.embedder_client = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(embedder_config),
|
||||
)
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.base_response import RobustLLMResponse
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
from pydantic import Field
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
from app.core.memory.models.base_response import RobustLLMResponse
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class MemorySummaryResponse(RobustLLMResponse):
|
||||
@@ -91,22 +87,17 @@ async def _process_chunk_summary(
|
||||
return None
|
||||
|
||||
|
||||
async def Memory_summary_generation(
|
||||
async def memory_summary_generation(
|
||||
chunked_dialogs: List[DialogData],
|
||||
llm_client,
|
||||
embedding_id,
|
||||
embedder_client: OpenAIEmbedderClient,
|
||||
) -> List[MemorySummaryNode]:
|
||||
"""Generate memory summaries per chunk, embed them, and return nodes."""
|
||||
embedder_cfg_dict = get_embedder_config(embedding_id)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(embedder_cfg_dict),
|
||||
)
|
||||
|
||||
# Collect all tasks for parallel processing
|
||||
tasks = []
|
||||
for dialog in chunked_dialogs:
|
||||
for chunk in dialog.chunks:
|
||||
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder))
|
||||
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder_client))
|
||||
|
||||
# Process all chunks in parallel
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo
|
||||
|
||||
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
from app.core.memory.models.variate_config import StatementExtractionConfig
|
||||
from app.core.memory.utils.data.ontology import (
|
||||
LABEL_DEFINITIONS,
|
||||
RelevenceInfo,
|
||||
StatementType,
|
||||
TemporalInfo,
|
||||
)
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo, RelevenceInfo
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -8,35 +8,43 @@
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.config import get_model_config
|
||||
from app.core.memory.utils.config.get_data import get_data,get_data_statement,extract_and_process_changes
|
||||
from app.core.memory.utils.config.get_data import (
|
||||
extract_and_process_changes,
|
||||
get_data,
|
||||
get_data_statement,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
||||
from app.core.memory.utils.prompt.template_render import (
|
||||
render_evaluate_prompt,
|
||||
render_reflexion_prompt,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.response_utils import success
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
UPDATE_STATEMENT_INVALID_AT,
|
||||
neo4j_query_all,
|
||||
neo4j_query_part,
|
||||
neo4j_statement_all,
|
||||
neo4j_statement_part,
|
||||
)
|
||||
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConflictResultSchema,
|
||||
ReflexionResultSchema,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
_root_logger = logging.getLogger()
|
||||
@@ -152,12 +160,26 @@ class ReflectionEngine:
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
|
||||
if self.llm_client is None:
|
||||
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
# from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
# model_id = self.llm_client
|
||||
# self.llm_client = get_llm_client(model_id)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
model_id = self.llm_client
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
# self.llm_client = factory.get_llm_client(model_id)
|
||||
|
||||
# Use MemoryConfigService to get model config
|
||||
config_service = MemoryConfigService(db)
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
|
||||
extra_params={
|
||||
"temperature": 0.2, # 降低温度提高响应速度和一致性
|
||||
"max_tokens": 600, # 限制最大token数
|
||||
@@ -165,7 +187,6 @@ class ReflectionEngine:
|
||||
"stream": False, # 确保非流式输出以获得最快响应
|
||||
}
|
||||
|
||||
model_config = get_model_config(self.llm_client)
|
||||
self.llm_client = OpenAIClient(RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
@@ -184,9 +205,15 @@ class ReflectionEngine:
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
if self.render_evaluate_prompt_func is None:
|
||||
from app.core.memory.utils.prompt.template_render import (
|
||||
render_evaluate_prompt,
|
||||
)
|
||||
self.render_evaluate_prompt_func = render_evaluate_prompt
|
||||
|
||||
if self.render_reflexion_prompt_func is None:
|
||||
from app.core.memory.utils.prompt.template_render import (
|
||||
render_reflexion_prompt,
|
||||
)
|
||||
self.render_reflexion_prompt_func = render_reflexion_prompt
|
||||
|
||||
if self.conflict_schema is None:
|
||||
@@ -196,6 +223,9 @@ class ReflectionEngine:
|
||||
self.reflexion_schema = ReflexionResultSchema
|
||||
|
||||
if self.update_query is None:
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
UPDATE_STATEMENT_INVALID_AT,
|
||||
)
|
||||
self.update_query = UPDATE_STATEMENT_INVALID_AT
|
||||
|
||||
self._lazy_init_done = True
|
||||
|
||||
@@ -4,10 +4,20 @@
|
||||
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
|
||||
"""
|
||||
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.search_strategy import (
|
||||
SearchResult,
|
||||
SearchStrategy,
|
||||
)
|
||||
from app.core.memory.storage_services.search.semantic_search import (
|
||||
SemanticSearchStrategy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SearchStrategy",
|
||||
@@ -34,7 +44,7 @@ async def run_hybrid_search(
|
||||
include: list[str] | None = None,
|
||||
alpha: float = 0.6,
|
||||
use_forgetting_curve: bool = False,
|
||||
embedding_id: str | None = None,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""运行混合搜索(向后兼容的函数式API)
|
||||
@@ -51,24 +61,26 @@ async def run_hybrid_search(
|
||||
include: 要包含的搜索类别列表
|
||||
alpha: BM25分数权重(0.0-1.0)
|
||||
use_forgetting_curve: 是否使用遗忘曲线
|
||||
embedding_id: 嵌入模型ID
|
||||
memory_config: MemoryConfig object containing embedding_model_id
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
dict: 搜索结果字典,格式与旧API兼容
|
||||
"""
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# 使用提供的embedding_id或默认值
|
||||
emb_id = embedding_id or config_defs.SELECTED_EMBEDDING_ID
|
||||
if not memory_config:
|
||||
raise ValueError("memory_config is required for search")
|
||||
|
||||
# 初始化客户端
|
||||
connector = Neo4jConnector()
|
||||
embedder_config_dict = get_embedder_config(emb_id)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
|
||||
@@ -1,447 +0,0 @@
|
||||
|
||||
# TODO hybrid_chatbot.py 是一个独立的GUI演示应用,不是核心功能的一部分,可以考虑删除
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
import tkinter as tk
|
||||
from tkinter import scrolledtext, messagebox
|
||||
import threading
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
# Import our hybrid search functionality
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.config_models import LLMConfig
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class HybridSearchChatbot:
|
||||
def __init__(self):
|
||||
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
# Chat history
|
||||
self.chat_history = []
|
||||
|
||||
# Search configuration
|
||||
self.search_config = {
|
||||
"group_id": "group_wyl_25",
|
||||
"limit": 10,
|
||||
"include": ["statements", "chunks", "entities","summaries"],
|
||||
# "include": ["statements", "dialogues", "entities"],
|
||||
"rerank_alpha": 0.6
|
||||
}
|
||||
|
||||
# Setup GUI
|
||||
self.setup_gui()
|
||||
|
||||
def setup_gui(self):
|
||||
"""Setup the GUI interface"""
|
||||
self.root = tk.Tk()
|
||||
self.root.title("Hybrid Search Chatbot")
|
||||
self.root.geometry("800x600")
|
||||
|
||||
# Chat display area
|
||||
self.chat_display = scrolledtext.ScrolledText(
|
||||
self.root,
|
||||
wrap=tk.WORD,
|
||||
width=80,
|
||||
height=25,
|
||||
state=tk.DISABLED
|
||||
)
|
||||
self.chat_display.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)
|
||||
|
||||
# Input frame
|
||||
input_frame = tk.Frame(self.root)
|
||||
input_frame.pack(padx=10, pady=5, fill=tk.X)
|
||||
|
||||
# User input
|
||||
self.user_input = tk.Entry(input_frame, font=("Arial", 12))
|
||||
self.user_input.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 5))
|
||||
self.user_input.bind("<Return>", self.on_send_message)
|
||||
|
||||
# Send button
|
||||
self.send_button = tk.Button(
|
||||
input_frame,
|
||||
text="发送",
|
||||
command=self.on_send_message,
|
||||
font=("Arial", 12)
|
||||
)
|
||||
self.send_button.pack(side=tk.RIGHT)
|
||||
|
||||
# Status frame
|
||||
status_frame = tk.Frame(self.root)
|
||||
status_frame.pack(padx=10, pady=5, fill=tk.X)
|
||||
|
||||
# Status label
|
||||
self.status_label = tk.Label(
|
||||
status_frame,
|
||||
text="就绪",
|
||||
font=("Arial", 10),
|
||||
anchor="w"
|
||||
)
|
||||
self.status_label.pack(side=tk.LEFT, fill=tk.X, expand=True)
|
||||
|
||||
# Search config button
|
||||
config_button = tk.Button(
|
||||
status_frame,
|
||||
text="搜索配置",
|
||||
command=self.show_config_dialog,
|
||||
font=("Arial", 10)
|
||||
)
|
||||
config_button.pack(side=tk.RIGHT)
|
||||
|
||||
# Add welcome message
|
||||
self.add_message("系统", "欢迎使用混合搜索聊天机器人!我可以基于知识图谱中的信息回答您的问题。")
|
||||
|
||||
def add_message(self, sender: str, message: str, metadata: Dict = None):
|
||||
"""Add a message to the chat display"""
|
||||
self.chat_display.config(state=tk.NORMAL)
|
||||
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
# Add sender and timestamp
|
||||
self.chat_display.insert(tk.END, f"[{timestamp}] {sender}:\n", "sender")
|
||||
|
||||
# Add message content
|
||||
self.chat_display.insert(tk.END, f"{message}\n", "message")
|
||||
|
||||
# Add metadata if available
|
||||
if metadata:
|
||||
self.chat_display.insert(tk.END, f" {metadata}\n", "metadata")
|
||||
|
||||
self.chat_display.insert(tk.END, "\n")
|
||||
self.chat_display.config(state=tk.DISABLED)
|
||||
self.chat_display.see(tk.END)
|
||||
|
||||
# Configure text tags for styling
|
||||
self.chat_display.tag_config("sender", foreground="blue", font=("Arial", 10, "bold"))
|
||||
self.chat_display.tag_config("message", foreground="black", font=("Arial", 10))
|
||||
self.chat_display.tag_config("metadata", foreground="gray", font=("Arial", 8))
|
||||
|
||||
def show_config_dialog(self):
|
||||
"""Show search configuration dialog"""
|
||||
config_window = tk.Toplevel(self.root)
|
||||
config_window.title("搜索配置")
|
||||
config_window.geometry("400x600")
|
||||
config_window.transient(self.root)
|
||||
config_window.grab_set()
|
||||
|
||||
# Current configuration display
|
||||
current_config_frame = tk.Frame(config_window)
|
||||
current_config_frame.pack(pady=10, padx=10, fill=tk.X)
|
||||
tk.Label(current_config_frame, text="当前配置:", font=("Arial", 10, "bold")).pack(anchor="w")
|
||||
current_text = f"Alpha: {self.search_config['rerank_alpha']}, 限制: {self.search_config['limit']}, 目标: {', '.join(self.search_config['include'])}"
|
||||
tk.Label(current_config_frame, text=current_text, font=("Arial", 9), fg="blue").pack(anchor="w")
|
||||
|
||||
# Alpha parameter
|
||||
tk.Label(config_window, text="重排权重 (Alpha):").pack(pady=(10, 5))
|
||||
alpha_var = tk.DoubleVar(value=self.search_config["rerank_alpha"])
|
||||
alpha_scale = tk.Scale(
|
||||
config_window,
|
||||
from_=0.0,
|
||||
to=1.0,
|
||||
resolution=0.1,
|
||||
orient=tk.HORIZONTAL,
|
||||
variable=alpha_var
|
||||
)
|
||||
alpha_scale.pack(pady=5, padx=20, fill=tk.X)
|
||||
tk.Label(config_window, text="0.0=纯语义搜索, 1.0=纯关键词搜索", font=("Arial", 8)).pack()
|
||||
|
||||
# Limit parameter
|
||||
tk.Label(config_window, text="搜索结果数量:").pack(pady=(20, 5))
|
||||
limit_var = tk.IntVar(value=self.search_config["limit"])
|
||||
limit_spinbox = tk.Spinbox(
|
||||
config_window,
|
||||
from_=1,
|
||||
to=50,
|
||||
textvariable=limit_var,
|
||||
width=10
|
||||
)
|
||||
limit_spinbox.pack(pady=5)
|
||||
|
||||
# Include options
|
||||
tk.Label(config_window, text="搜索目标:").pack(pady=(20, 5))
|
||||
include_frame = tk.Frame(config_window)
|
||||
include_frame.pack(pady=5)
|
||||
|
||||
include_vars = {}
|
||||
for option in ["statements", "chunks", "entities","summaries"]:
|
||||
var = tk.BooleanVar(value=option in self.search_config["include"])
|
||||
include_vars[option] = var
|
||||
tk.Checkbutton(
|
||||
include_frame,
|
||||
text=option,
|
||||
variable=var
|
||||
).pack(side=tk.LEFT, padx=10)
|
||||
|
||||
# Buttons
|
||||
button_frame = tk.Frame(config_window)
|
||||
button_frame.pack(pady=20)
|
||||
|
||||
def save_config():
|
||||
try:
|
||||
# Validate inputs
|
||||
alpha_value = alpha_var.get()
|
||||
limit_value = limit_var.get()
|
||||
include_list = [
|
||||
option for option, var in include_vars.items() if var.get()
|
||||
]
|
||||
|
||||
# Check if at least one search target is selected
|
||||
if not include_list:
|
||||
messagebox.showerror("配置错误", "请至少选择一个搜索目标!")
|
||||
return
|
||||
|
||||
# Update configuration
|
||||
self.search_config["rerank_alpha"] = alpha_value
|
||||
self.search_config["limit"] = limit_value
|
||||
self.search_config["include"] = include_list
|
||||
|
||||
config_window.destroy()
|
||||
self.add_message("系统",
|
||||
f"配置已更新: Alpha={alpha_value:.1f}, 限制={limit_value}, 目标={', '.join(include_list)}")
|
||||
|
||||
except Exception as e:
|
||||
messagebox.showerror("配置错误", f"保存配置时出错: {str(e)}")
|
||||
print(f"Config save error: {e}") # Debug output
|
||||
|
||||
tk.Button(button_frame, text="保存", command=save_config).pack(side=tk.LEFT, padx=5)
|
||||
tk.Button(button_frame, text="取消", command=config_window.destroy).pack(side=tk.LEFT, padx=5)
|
||||
|
||||
def on_send_message(self, event=None):
|
||||
"""Handle sending a message"""
|
||||
user_message = self.user_input.get().strip()
|
||||
if not user_message:
|
||||
return
|
||||
|
||||
# Clear input
|
||||
self.user_input.delete(0, tk.END)
|
||||
|
||||
# Add user message to display
|
||||
self.add_message("用户", user_message)
|
||||
|
||||
# Disable send button and show processing status
|
||||
self.send_button.config(state=tk.DISABLED)
|
||||
self.status_label.config(text="正在搜索和生成回复...")
|
||||
|
||||
# Process message in background thread
|
||||
threading.Thread(
|
||||
target=self.process_message_async,
|
||||
args=(user_message,),
|
||||
daemon=True
|
||||
).start()
|
||||
|
||||
def process_message_async(self, user_message: str):
|
||||
"""Process message asynchronously"""
|
||||
try:
|
||||
# Run the async processing
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
response, metadata = loop.run_until_complete(
|
||||
self.process_message(user_message)
|
||||
)
|
||||
loop.close()
|
||||
|
||||
# Update GUI in main thread
|
||||
self.root.after(0, self.on_response_ready, response, metadata)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"处理消息时出错: {str(e)}"
|
||||
self.root.after(0, self.on_error, error_msg)
|
||||
|
||||
async def process_message(self, user_message: str) -> Tuple[str, Dict[str, Any]]:
|
||||
"""Process user message with hybrid search"""
|
||||
start_time = time.time()
|
||||
|
||||
# Perform hybrid search
|
||||
search_start = time.time()
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=user_message,
|
||||
search_type="hybrid",
|
||||
group_id=self.search_config["group_id"],
|
||||
limit=self.search_config["limit"],
|
||||
include=self.search_config["include"],
|
||||
output_path=None,
|
||||
rerank_alpha=self.search_config["rerank_alpha"]
|
||||
)
|
||||
search_time = time.time() - search_start
|
||||
|
||||
# Extract relevant information from search results
|
||||
context_info = self.extract_context_from_search(search_results)
|
||||
|
||||
# Generate response using LLM
|
||||
llm_start = time.time()
|
||||
response = await self.generate_response(user_message, context_info)
|
||||
llm_time = time.time() - llm_start
|
||||
|
||||
total_time = time.time() - start_time
|
||||
|
||||
# Prepare metadata
|
||||
metadata = {
|
||||
"搜索时间": f"{search_time:.2f}s",
|
||||
"生成时间": f"{llm_time:.2f}s",
|
||||
"总时间": f"{total_time:.2f}s",
|
||||
"搜索结果": self.get_search_summary(search_results),
|
||||
"重排权重": self.search_config["rerank_alpha"]
|
||||
}
|
||||
|
||||
return response, metadata
|
||||
|
||||
def extract_context_from_search(self, search_results: Dict) -> str:
|
||||
"""Extract context information from search results"""
|
||||
if not search_results:
|
||||
return "未找到相关信息。"
|
||||
|
||||
context_parts = []
|
||||
|
||||
# Get reranked results if available, otherwise use individual results
|
||||
if "reranked_results" in search_results:
|
||||
results = search_results["reranked_results"]
|
||||
else:
|
||||
results = {}
|
||||
for key in ["keyword_search", "embedding_search"]:
|
||||
if key in search_results:
|
||||
for category, items in search_results[key].items():
|
||||
if category not in results:
|
||||
results[category] = []
|
||||
results[category].extend(items)
|
||||
|
||||
# Extract statements
|
||||
if "statements" in results and results["statements"]:
|
||||
statements = results["statements"][:5] # Top 5
|
||||
context_parts.append("相关陈述:")
|
||||
for i, stmt in enumerate(statements, 1):
|
||||
content = stmt.get("statement", "")
|
||||
score = stmt.get("combined_score", stmt.get("score", 0))
|
||||
context_parts.append(f"{i}. {content} (相关度: {score:.3f})")
|
||||
|
||||
# Extract chunks
|
||||
if "chunks" in results and results["chunks"]:
|
||||
chunks = results["chunks"][:3] # Top 3
|
||||
context_parts.append("\n相关对话:")
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
content = chunk.get("content", "")
|
||||
score = chunk.get("combined_score", chunk.get("score", 0))
|
||||
context_parts.append(f"{i}. {content} (相关度: {score:.3f})")
|
||||
|
||||
# Extract entities
|
||||
if "entities" in results and results["entities"]:
|
||||
entities = results["entities"][:5] # Top 5
|
||||
context_parts.append("\n相关实体:")
|
||||
entity_names = [ent.get("name", "") for ent in entities]
|
||||
context_parts.append(", ".join(entity_names))
|
||||
|
||||
return "\n".join(context_parts) if context_parts else "未找到相关信息。"
|
||||
|
||||
def get_search_summary(self, search_results: Dict) -> str:
|
||||
"""Get a summary of search results"""
|
||||
if not search_results:
|
||||
return "无结果"
|
||||
|
||||
summary_parts = []
|
||||
|
||||
if "combined_summary" in search_results:
|
||||
summary = search_results["combined_summary"]
|
||||
if "total_reranked_results" in summary:
|
||||
summary_parts.append(f"重排结果: {summary['total_reranked_results']}")
|
||||
if "total_keyword_results" in summary:
|
||||
summary_parts.append(f"关键词: {summary['total_keyword_results']}")
|
||||
if "total_embedding_results" in summary:
|
||||
summary_parts.append(f"语义: {summary['total_embedding_results']}")
|
||||
|
||||
return ", ".join(summary_parts) if summary_parts else "有结果"
|
||||
|
||||
async def generate_response(self, user_message: str, context: str) -> str:
|
||||
"""Generate response using LLM"""
|
||||
system_prompt = f"""你是一个智能助手,基于知识图谱中的信息回答用户问题。
|
||||
|
||||
以下是从知识图谱中检索到的相关信息:
|
||||
{context}
|
||||
|
||||
请基于这些信息回答用户的问题。如果信息不足,请诚实地说明。回答要自然、友好,并且准确。"""
|
||||
|
||||
try:
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_message}
|
||||
]
|
||||
|
||||
response = self.llm_client.chat(
|
||||
messages=messages,
|
||||
)
|
||||
print(response)
|
||||
# Extract content from various possible response types
|
||||
# 1) LangChain AIMessage or similar object with `.content`
|
||||
if hasattr(response, 'content'):
|
||||
return getattr(response, 'content')
|
||||
|
||||
# 2) OpenAI-style response with `.choices`
|
||||
if hasattr(response, 'choices') and response.choices:
|
||||
first_choice = response.choices[0]
|
||||
# Newer clients may have `.message.content`, some have `.content` directly
|
||||
if hasattr(first_choice, 'message') and hasattr(first_choice.message, 'content'):
|
||||
return first_choice.message.content
|
||||
if hasattr(first_choice, 'content'):
|
||||
return first_choice.content
|
||||
|
||||
# 3) Dict-like responses
|
||||
if isinstance(response, dict):
|
||||
if 'content' in response:
|
||||
return response['content']
|
||||
if 'choices' in response and response['choices']:
|
||||
ch = response['choices'][0]
|
||||
if isinstance(ch, dict):
|
||||
if 'message' in ch and 'content' in ch['message']:
|
||||
return ch['message']['content']
|
||||
if 'content' in ch:
|
||||
return ch['content']
|
||||
|
||||
# 4) Fallback: if it's a plain string
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
|
||||
# Default fallback
|
||||
return "抱歉,我无法生成回复。"
|
||||
|
||||
except Exception as e:
|
||||
return f"生成回复时出错: {str(e)}"
|
||||
|
||||
def on_response_ready(self, response: str, metadata: Dict[str, Any]):
|
||||
"""Handle when response is ready"""
|
||||
self.add_message("助手", response, metadata)
|
||||
self.send_button.config(state=tk.NORMAL)
|
||||
self.status_label.config(text="就绪")
|
||||
self.user_input.focus()
|
||||
|
||||
def on_error(self, error_message: str):
|
||||
"""Handle errors"""
|
||||
self.add_message("系统", f" {error_message}")
|
||||
self.send_button.config(state=tk.NORMAL)
|
||||
self.status_label.config(text="就绪")
|
||||
self.user_input.focus()
|
||||
|
||||
def run(self):
|
||||
"""Start the chatbot"""
|
||||
self.root.mainloop()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run the chatbot"""
|
||||
try:
|
||||
chatbot = HybridSearchChatbot()
|
||||
chatbot.run()
|
||||
except Exception as e:
|
||||
print(f"启动聊天机器人时出错: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,408 +1,408 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""混合搜索策略
|
||||
# # -*- coding: utf-8 -*-
|
||||
# """混合搜索策略
|
||||
|
||||
结合关键词搜索和语义搜索的混合检索方法。
|
||||
支持结果重排序和遗忘曲线加权。
|
||||
"""
|
||||
# 结合关键词搜索和语义搜索的混合检索方法。
|
||||
# 支持结果重排序和遗忘曲线加权。
|
||||
# """
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
import math
|
||||
from datetime import datetime
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
# from typing import List, Dict, Any, Optional
|
||||
# import math
|
||||
# from datetime import datetime
|
||||
# from app.core.logging_config import get_memory_logger
|
||||
# from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
# from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
# from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
|
||||
# from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
# from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
# from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
# logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
class HybridSearchStrategy(SearchStrategy):
|
||||
"""混合搜索策略
|
||||
# class HybridSearchStrategy(SearchStrategy):
|
||||
# """混合搜索策略
|
||||
|
||||
结合关键词搜索和语义搜索的优势:
|
||||
- 关键词搜索:精确匹配,适合已知术语
|
||||
- 语义搜索:语义理解,适合概念查询
|
||||
- 混合重排序:综合两种搜索的结果
|
||||
- 遗忘曲线:根据时间衰减调整相关性
|
||||
"""
|
||||
# 结合关键词搜索和语义搜索的优势:
|
||||
# - 关键词搜索:精确匹配,适合已知术语
|
||||
# - 语义搜索:语义理解,适合概念查询
|
||||
# - 混合重排序:综合两种搜索的结果
|
||||
# - 遗忘曲线:根据时间衰减调整相关性
|
||||
# """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
embedder_client: Optional[OpenAIEmbedderClient] = None,
|
||||
alpha: float = 0.6,
|
||||
use_forgetting_curve: bool = False,
|
||||
forgetting_config: Optional[ForgettingEngineConfig] = None
|
||||
):
|
||||
"""初始化混合搜索策略
|
||||
# def __init__(
|
||||
# self,
|
||||
# connector: Optional[Neo4jConnector] = None,
|
||||
# embedder_client: Optional[OpenAIEmbedderClient] = None,
|
||||
# alpha: float = 0.6,
|
||||
# use_forgetting_curve: bool = False,
|
||||
# forgetting_config: Optional[ForgettingEngineConfig] = None
|
||||
# ):
|
||||
# """初始化混合搜索策略
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器
|
||||
embedder_client: 嵌入模型客户端
|
||||
alpha: BM25分数权重(0.0-1.0),1-alpha为嵌入分数权重
|
||||
use_forgetting_curve: 是否使用遗忘曲线
|
||||
forgetting_config: 遗忘引擎配置
|
||||
"""
|
||||
self.connector = connector
|
||||
self.embedder_client = embedder_client
|
||||
self.alpha = alpha
|
||||
self.use_forgetting_curve = use_forgetting_curve
|
||||
self.forgetting_config = forgetting_config or ForgettingEngineConfig()
|
||||
self._owns_connector = connector is None
|
||||
# Args:
|
||||
# connector: Neo4j连接器
|
||||
# embedder_client: 嵌入模型客户端
|
||||
# alpha: BM25分数权重(0.0-1.0),1-alpha为嵌入分数权重
|
||||
# use_forgetting_curve: 是否使用遗忘曲线
|
||||
# forgetting_config: 遗忘引擎配置
|
||||
# """
|
||||
# self.connector = connector
|
||||
# self.embedder_client = embedder_client
|
||||
# self.alpha = alpha
|
||||
# self.use_forgetting_curve = use_forgetting_curve
|
||||
# self.forgetting_config = forgetting_config or ForgettingEngineConfig()
|
||||
# self._owns_connector = connector is None
|
||||
|
||||
# 创建子策略
|
||||
self.keyword_strategy = KeywordSearchStrategy(connector=connector)
|
||||
self.semantic_strategy = SemanticSearchStrategy(
|
||||
connector=connector,
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
# # 创建子策略
|
||||
# self.keyword_strategy = KeywordSearchStrategy(connector=connector)
|
||||
# self.semantic_strategy = SemanticSearchStrategy(
|
||||
# connector=connector,
|
||||
# embedder_client=embedder_client
|
||||
# )
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
if self._owns_connector:
|
||||
self.connector = Neo4jConnector()
|
||||
self.keyword_strategy.connector = self.connector
|
||||
self.semantic_strategy.connector = self.connector
|
||||
return self
|
||||
# async def __aenter__(self):
|
||||
# """异步上下文管理器入口"""
|
||||
# if self._owns_connector:
|
||||
# self.connector = Neo4jConnector()
|
||||
# self.keyword_strategy.connector = self.connector
|
||||
# self.semantic_strategy.connector = self.connector
|
||||
# return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
if self._owns_connector and self.connector:
|
||||
await self.connector.close()
|
||||
# async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# """异步上下文管理器出口"""
|
||||
# if self._owns_connector and self.connector:
|
||||
# await self.connector.close()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
) -> SearchResult:
|
||||
"""执行混合搜索
|
||||
# async def search(
|
||||
# self,
|
||||
# query_text: str,
|
||||
# group_id: Optional[str] = None,
|
||||
# limit: int = 50,
|
||||
# include: Optional[List[str]] = None,
|
||||
# **kwargs
|
||||
# ) -> SearchResult:
|
||||
# """执行混合搜索
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
group_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
**kwargs: 其他搜索参数(如alpha, use_forgetting_curve)
|
||||
# Args:
|
||||
# query_text: 查询文本
|
||||
# group_id: 可选的组ID过滤
|
||||
# limit: 每个类别的最大结果数
|
||||
# include: 要包含的搜索类别列表
|
||||
# **kwargs: 其他搜索参数(如alpha, use_forgetting_curve)
|
||||
|
||||
Returns:
|
||||
SearchResult: 搜索结果对象
|
||||
"""
|
||||
logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
# Returns:
|
||||
# SearchResult: 搜索结果对象
|
||||
# """
|
||||
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
|
||||
# 从kwargs中获取参数
|
||||
alpha = kwargs.get("alpha", self.alpha)
|
||||
use_forgetting = kwargs.get("use_forgetting_curve", self.use_forgetting_curve)
|
||||
# # 从kwargs中获取参数
|
||||
# alpha = kwargs.get("alpha", self.alpha)
|
||||
# use_forgetting = kwargs.get("use_forgetting_curve", self.use_forgetting_curve)
|
||||
|
||||
# 获取有效的搜索类别
|
||||
include_list = self._get_include_list(include)
|
||||
# # 获取有效的搜索类别
|
||||
# include_list = self._get_include_list(include)
|
||||
|
||||
try:
|
||||
# 并行执行关键词搜索和语义搜索
|
||||
keyword_result = await self.keyword_strategy.search(
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
# try:
|
||||
# # 并行执行关键词搜索和语义搜索
|
||||
# keyword_result = await self.keyword_strategy.search(
|
||||
# query_text=query_text,
|
||||
# group_id=group_id,
|
||||
# limit=limit,
|
||||
# include=include_list
|
||||
# )
|
||||
|
||||
semantic_result = await self.semantic_strategy.search(
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
# semantic_result = await self.semantic_strategy.search(
|
||||
# query_text=query_text,
|
||||
# group_id=group_id,
|
||||
# limit=limit,
|
||||
# include=include_list
|
||||
# )
|
||||
|
||||
# 重排序结果
|
||||
if use_forgetting:
|
||||
reranked_results = self._rerank_with_forgetting_curve(
|
||||
keyword_result=keyword_result,
|
||||
semantic_result=semantic_result,
|
||||
alpha=alpha,
|
||||
limit=limit
|
||||
)
|
||||
else:
|
||||
reranked_results = self._rerank_hybrid_results(
|
||||
keyword_result=keyword_result,
|
||||
semantic_result=semantic_result,
|
||||
alpha=alpha,
|
||||
limit=limit
|
||||
)
|
||||
# # 重排序结果
|
||||
# if use_forgetting:
|
||||
# reranked_results = self._rerank_with_forgetting_curve(
|
||||
# keyword_result=keyword_result,
|
||||
# semantic_result=semantic_result,
|
||||
# alpha=alpha,
|
||||
# limit=limit
|
||||
# )
|
||||
# else:
|
||||
# reranked_results = self._rerank_hybrid_results(
|
||||
# keyword_result=keyword_result,
|
||||
# semantic_result=semantic_result,
|
||||
# alpha=alpha,
|
||||
# limit=limit
|
||||
# )
|
||||
|
||||
# 创建元数据
|
||||
metadata = self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="hybrid",
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
include=include_list,
|
||||
alpha=alpha,
|
||||
use_forgetting_curve=use_forgetting
|
||||
)
|
||||
# # 创建元数据
|
||||
# metadata = self._create_metadata(
|
||||
# query_text=query_text,
|
||||
# search_type="hybrid",
|
||||
# group_id=group_id,
|
||||
# limit=limit,
|
||||
# include=include_list,
|
||||
# alpha=alpha,
|
||||
# use_forgetting_curve=use_forgetting
|
||||
# )
|
||||
|
||||
# 添加结果统计
|
||||
metadata["keyword_results"] = keyword_result.metadata.get("result_counts", {})
|
||||
metadata["semantic_results"] = semantic_result.metadata.get("result_counts", {})
|
||||
metadata["total_keyword_results"] = keyword_result.total_results()
|
||||
metadata["total_semantic_results"] = semantic_result.total_results()
|
||||
metadata["total_reranked_results"] = reranked_results.total_results()
|
||||
# # 添加结果统计
|
||||
# metadata["keyword_results"] = keyword_result.metadata.get("result_counts", {})
|
||||
# metadata["semantic_results"] = semantic_result.metadata.get("result_counts", {})
|
||||
# metadata["total_keyword_results"] = keyword_result.total_results()
|
||||
# metadata["total_semantic_results"] = semantic_result.total_results()
|
||||
# metadata["total_reranked_results"] = reranked_results.total_results()
|
||||
|
||||
reranked_results.metadata = metadata
|
||||
# reranked_results.metadata = metadata
|
||||
|
||||
logger.info(f"混合搜索完成: 共找到 {reranked_results.total_results()} 条结果")
|
||||
return reranked_results
|
||||
# logger.info(f"混合搜索完成: 共找到 {reranked_results.total_results()} 条结果")
|
||||
# return reranked_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"混合搜索失败: {e}", exc_info=True)
|
||||
# 返回空结果但包含错误信息
|
||||
return SearchResult(
|
||||
metadata=self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="hybrid",
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
error=str(e)
|
||||
)
|
||||
)
|
||||
# except Exception as e:
|
||||
# logger.error(f"混合搜索失败: {e}", exc_info=True)
|
||||
# # 返回空结果但包含错误信息
|
||||
# return SearchResult(
|
||||
# metadata=self._create_metadata(
|
||||
# query_text=query_text,
|
||||
# search_type="hybrid",
|
||||
# group_id=group_id,
|
||||
# limit=limit,
|
||||
# error=str(e)
|
||||
# )
|
||||
# )
|
||||
|
||||
def _normalize_scores(
|
||||
self,
|
||||
results: List[Dict[str, Any]],
|
||||
score_field: str = "score"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""使用z-score标准化和sigmoid转换归一化分数
|
||||
# def _normalize_scores(
|
||||
# self,
|
||||
# results: List[Dict[str, Any]],
|
||||
# score_field: str = "score"
|
||||
# ) -> List[Dict[str, Any]]:
|
||||
# """使用z-score标准化和sigmoid转换归一化分数
|
||||
|
||||
Args:
|
||||
results: 结果列表
|
||||
score_field: 分数字段名
|
||||
# Args:
|
||||
# results: 结果列表
|
||||
# score_field: 分数字段名
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 归一化后的结果列表
|
||||
"""
|
||||
if not results:
|
||||
return results
|
||||
# Returns:
|
||||
# List[Dict[str, Any]]: 归一化后的结果列表
|
||||
# """
|
||||
# if not results:
|
||||
# return results
|
||||
|
||||
# 提取分数
|
||||
scores = []
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
score = item.get(score_field)
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
scores.append(float(score))
|
||||
else:
|
||||
scores.append(0.0)
|
||||
# # 提取分数
|
||||
# scores = []
|
||||
# for item in results:
|
||||
# if score_field in item:
|
||||
# score = item.get(score_field)
|
||||
# if score is not None and isinstance(score, (int, float)):
|
||||
# scores.append(float(score))
|
||||
# else:
|
||||
# scores.append(0.0)
|
||||
|
||||
if not scores or len(scores) == 1:
|
||||
# 单个分数或无分数,设置为1.0
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
return results
|
||||
# if not scores or len(scores) == 1:
|
||||
# # 单个分数或无分数,设置为1.0
|
||||
# for item in results:
|
||||
# if score_field in item:
|
||||
# item[f"normalized_{score_field}"] = 1.0
|
||||
# return results
|
||||
|
||||
# 计算均值和标准差
|
||||
mean_score = sum(scores) / len(scores)
|
||||
variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
|
||||
std_dev = math.sqrt(variance)
|
||||
# # 计算均值和标准差
|
||||
# mean_score = sum(scores) / len(scores)
|
||||
# variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
|
||||
# std_dev = math.sqrt(variance)
|
||||
|
||||
if std_dev == 0:
|
||||
# 所有分数相同,设置为1.0
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
else:
|
||||
# z-score标准化 + sigmoid转换
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
score = item[score_field]
|
||||
if score is None or not isinstance(score, (int, float)):
|
||||
score = 0.0
|
||||
z_score = (score - mean_score) / std_dev
|
||||
normalized = 1 / (1 + math.exp(-z_score))
|
||||
item[f"normalized_{score_field}"] = normalized
|
||||
# if std_dev == 0:
|
||||
# # 所有分数相同,设置为1.0
|
||||
# for item in results:
|
||||
# if score_field in item:
|
||||
# item[f"normalized_{score_field}"] = 1.0
|
||||
# else:
|
||||
# # z-score标准化 + sigmoid转换
|
||||
# for item in results:
|
||||
# if score_field in item:
|
||||
# score = item[score_field]
|
||||
# if score is None or not isinstance(score, (int, float)):
|
||||
# score = 0.0
|
||||
# z_score = (score - mean_score) / std_dev
|
||||
# normalized = 1 / (1 + math.exp(-z_score))
|
||||
# item[f"normalized_{score_field}"] = normalized
|
||||
|
||||
return results
|
||||
# return results
|
||||
|
||||
def _rerank_hybrid_results(
|
||||
self,
|
||||
keyword_result: SearchResult,
|
||||
semantic_result: SearchResult,
|
||||
alpha: float,
|
||||
limit: int
|
||||
) -> SearchResult:
|
||||
"""重排序混合搜索结果
|
||||
# def _rerank_hybrid_results(
|
||||
# self,
|
||||
# keyword_result: SearchResult,
|
||||
# semantic_result: SearchResult,
|
||||
# alpha: float,
|
||||
# limit: int
|
||||
# ) -> SearchResult:
|
||||
# """重排序混合搜索结果
|
||||
|
||||
Args:
|
||||
keyword_result: 关键词搜索结果
|
||||
semantic_result: 语义搜索结果
|
||||
alpha: BM25分数权重
|
||||
limit: 结果限制
|
||||
# Args:
|
||||
# keyword_result: 关键词搜索结果
|
||||
# semantic_result: 语义搜索结果
|
||||
# alpha: BM25分数权重
|
||||
# limit: 结果限制
|
||||
|
||||
Returns:
|
||||
SearchResult: 重排序后的结果
|
||||
"""
|
||||
reranked_data = {}
|
||||
# Returns:
|
||||
# SearchResult: 重排序后的结果
|
||||
# """
|
||||
# reranked_data = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
keyword_items = getattr(keyword_result, category, [])
|
||||
semantic_items = getattr(semantic_result, category, [])
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
# keyword_items = getattr(keyword_result, category, [])
|
||||
# semantic_items = getattr(semantic_result, category, [])
|
||||
|
||||
# 归一化分数
|
||||
keyword_items = self._normalize_scores(keyword_items, "score")
|
||||
semantic_items = self._normalize_scores(semantic_items, "score")
|
||||
# # 归一化分数
|
||||
# keyword_items = self._normalize_scores(keyword_items, "score")
|
||||
# semantic_items = self._normalize_scores(semantic_items, "score")
|
||||
|
||||
# 合并结果
|
||||
combined_items = {}
|
||||
# # 合并结果
|
||||
# combined_items = {}
|
||||
|
||||
# 添加关键词结果
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
if item_id:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
combined_items[item_id]["embedding_score"] = 0
|
||||
# # 添加关键词结果
|
||||
# for item in keyword_items:
|
||||
# item_id = item.get("id") or item.get("uuid")
|
||||
# if item_id:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
# combined_items[item_id]["embedding_score"] = 0
|
||||
|
||||
# 添加或更新语义结果
|
||||
for item in semantic_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
if item_id:
|
||||
if item_id in combined_items:
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# # 添加或更新语义结果
|
||||
# for item in semantic_items:
|
||||
# item_id = item.get("id") or item.get("uuid")
|
||||
# if item_id:
|
||||
# if item_id in combined_items:
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# 计算组合分数
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_score = item.get("bm25_score", 0)
|
||||
embedding_score = item.get("embedding_score", 0)
|
||||
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
item["combined_score"] = combined_score
|
||||
# # 计算组合分数
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = item.get("bm25_score", 0)
|
||||
# embedding_score = item.get("embedding_score", 0)
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
# item["combined_score"] = combined_score
|
||||
|
||||
# 排序并限制结果
|
||||
sorted_items = sorted(
|
||||
combined_items.values(),
|
||||
key=lambda x: x.get("combined_score", 0),
|
||||
reverse=True
|
||||
)[:limit]
|
||||
# # 排序并限制结果
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(),
|
||||
# key=lambda x: x.get("combined_score", 0),
|
||||
# reverse=True
|
||||
# )[:limit]
|
||||
|
||||
reranked_data[category] = sorted_items
|
||||
# reranked_data[category] = sorted_items
|
||||
|
||||
return SearchResult(
|
||||
statements=reranked_data.get("statements", []),
|
||||
chunks=reranked_data.get("chunks", []),
|
||||
entities=reranked_data.get("entities", []),
|
||||
summaries=reranked_data.get("summaries", [])
|
||||
)
|
||||
# return SearchResult(
|
||||
# statements=reranked_data.get("statements", []),
|
||||
# chunks=reranked_data.get("chunks", []),
|
||||
# entities=reranked_data.get("entities", []),
|
||||
# summaries=reranked_data.get("summaries", [])
|
||||
# )
|
||||
|
||||
def _parse_datetime(self, value: Any) -> Optional[datetime]:
|
||||
"""解析日期时间字符串"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
s = value.strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(s)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
# def _parse_datetime(self, value: Any) -> Optional[datetime]:
|
||||
# """解析日期时间字符串"""
|
||||
# if value is None:
|
||||
# return None
|
||||
# if isinstance(value, datetime):
|
||||
# return value
|
||||
# if isinstance(value, str):
|
||||
# s = value.strip()
|
||||
# if not s:
|
||||
# return None
|
||||
# try:
|
||||
# return datetime.fromisoformat(s)
|
||||
# except Exception:
|
||||
# return None
|
||||
# return None
|
||||
|
||||
def _rerank_with_forgetting_curve(
|
||||
self,
|
||||
keyword_result: SearchResult,
|
||||
semantic_result: SearchResult,
|
||||
alpha: float,
|
||||
limit: int
|
||||
) -> SearchResult:
|
||||
"""使用遗忘曲线重排序混合搜索结果
|
||||
# def _rerank_with_forgetting_curve(
|
||||
# self,
|
||||
# keyword_result: SearchResult,
|
||||
# semantic_result: SearchResult,
|
||||
# alpha: float,
|
||||
# limit: int
|
||||
# ) -> SearchResult:
|
||||
# """使用遗忘曲线重排序混合搜索结果
|
||||
|
||||
Args:
|
||||
keyword_result: 关键词搜索结果
|
||||
semantic_result: 语义搜索结果
|
||||
alpha: BM25分数权重
|
||||
limit: 结果限制
|
||||
# Args:
|
||||
# keyword_result: 关键词搜索结果
|
||||
# semantic_result: 语义搜索结果
|
||||
# alpha: BM25分数权重
|
||||
# limit: 结果限制
|
||||
|
||||
Returns:
|
||||
SearchResult: 重排序后的结果
|
||||
"""
|
||||
engine = ForgettingEngine(self.forgetting_config)
|
||||
now_dt = datetime.now()
|
||||
# Returns:
|
||||
# SearchResult: 重排序后的结果
|
||||
# """
|
||||
# engine = ForgettingEngine(self.forgetting_config)
|
||||
# now_dt = datetime.now()
|
||||
|
||||
reranked_data = {}
|
||||
# reranked_data = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
keyword_items = getattr(keyword_result, category, [])
|
||||
semantic_items = getattr(semantic_result, category, [])
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
# keyword_items = getattr(keyword_result, category, [])
|
||||
# semantic_items = getattr(semantic_result, category, [])
|
||||
|
||||
# 归一化分数
|
||||
keyword_items = self._normalize_scores(keyword_items, "score")
|
||||
semantic_items = self._normalize_scores(semantic_items, "score")
|
||||
# # 归一化分数
|
||||
# keyword_items = self._normalize_scores(keyword_items, "score")
|
||||
# semantic_items = self._normalize_scores(semantic_items, "score")
|
||||
|
||||
# 合并结果
|
||||
combined_items = {}
|
||||
# # 合并结果
|
||||
# combined_items = {}
|
||||
|
||||
for src_items, is_embedding in [(keyword_items, False), (semantic_items, True)]:
|
||||
for item in src_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
if not item_id:
|
||||
continue
|
||||
# for src_items, is_embedding in [(keyword_items, False), (semantic_items, True)]:
|
||||
# for item in src_items:
|
||||
# item_id = item.get("id") or item.get("uuid")
|
||||
# if not item_id:
|
||||
# continue
|
||||
|
||||
if item_id not in combined_items:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0
|
||||
combined_items[item_id]["embedding_score"] = 0
|
||||
# if item_id not in combined_items:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0
|
||||
# combined_items[item_id]["embedding_score"] = 0
|
||||
|
||||
if is_embedding:
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
# if is_embedding:
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# 计算分数并应用遗忘权重
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_score = float(item.get("bm25_score", 0) or 0)
|
||||
embedding_score = float(item.get("embedding_score", 0) or 0)
|
||||
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
# # 计算分数并应用遗忘权重
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = float(item.get("bm25_score", 0) or 0)
|
||||
# embedding_score = float(item.get("embedding_score", 0) or 0)
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
|
||||
# 计算时间衰减
|
||||
dt = self._parse_datetime(item.get("created_at"))
|
||||
if dt is None:
|
||||
time_elapsed_days = 0.0
|
||||
else:
|
||||
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
# # 计算时间衰减
|
||||
# dt = self._parse_datetime(item.get("created_at"))
|
||||
# if dt is None:
|
||||
# time_elapsed_days = 0.0
|
||||
# else:
|
||||
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
memory_strength = 1.0 # 默认强度
|
||||
forgetting_weight = engine.calculate_weight(
|
||||
time_elapsed=time_elapsed_days,
|
||||
memory_strength=memory_strength
|
||||
)
|
||||
# memory_strength = 1.0 # 默认强度
|
||||
# forgetting_weight = engine.calculate_weight(
|
||||
# time_elapsed=time_elapsed_days,
|
||||
# memory_strength=memory_strength
|
||||
# )
|
||||
|
||||
final_score = combined_score * forgetting_weight
|
||||
item["combined_score"] = final_score
|
||||
item["forgetting_weight"] = forgetting_weight
|
||||
item["time_elapsed_days"] = time_elapsed_days
|
||||
# final_score = combined_score * forgetting_weight
|
||||
# item["combined_score"] = final_score
|
||||
# item["forgetting_weight"] = forgetting_weight
|
||||
# item["time_elapsed_days"] = time_elapsed_days
|
||||
|
||||
# 排序并限制结果
|
||||
sorted_items = sorted(
|
||||
combined_items.values(),
|
||||
key=lambda x: x.get("combined_score", 0),
|
||||
reverse=True
|
||||
)[:limit]
|
||||
# # 排序并限制结果
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(),
|
||||
# key=lambda x: x.get("combined_score", 0),
|
||||
# reverse=True
|
||||
# )[:limit]
|
||||
|
||||
reranked_data[category] = sorted_items
|
||||
# reranked_data[category] = sorted_items
|
||||
|
||||
return SearchResult(
|
||||
statements=reranked_data.get("statements", []),
|
||||
chunks=reranked_data.get("chunks", []),
|
||||
entities=reranked_data.get("entities", []),
|
||||
summaries=reranked_data.get("summaries", [])
|
||||
)
|
||||
# return SearchResult(
|
||||
# statements=reranked_data.get("statements", []),
|
||||
# chunks=reranked_data.get("chunks", []),
|
||||
# entities=reranked_data.get("entities", []),
|
||||
# summaries=reranked_data.get("summaries", [])
|
||||
# )
|
||||
|
||||
@@ -5,15 +5,20 @@
|
||||
使用余弦相似度进行语义匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.storage_services.search.search_strategy import (
|
||||
SearchResult,
|
||||
SearchStrategy,
|
||||
)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
@@ -62,7 +67,9 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
"""
|
||||
try:
|
||||
# 从数据库读取嵌入器配置
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
|
||||
@@ -1,445 +0,0 @@
|
||||
# Memory 模块工具函数文档
|
||||
|
||||
本目录包含 Memory 模块使用的所有工具函数,统一管理以提高代码可维护性和可复用性。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
app/core/memory/utils/
|
||||
├── __init__.py # 包初始化文件,导出所有公共接口
|
||||
├── README.md # 本文档
|
||||
├── config/ # 配置管理模块
|
||||
│ ├── __init__.py # 配置模块初始化
|
||||
│ ├── config_utils.py # 配置管理工具
|
||||
│ ├── definitions.py # 全局定义和常量
|
||||
│ ├── overrides.py # 运行时配置覆写
|
||||
│ ├── get_data.py # 数据获取工具
|
||||
│ ├── litellm_config.py # LiteLLM 配置和监控
|
||||
│ └── config_optimization.py # 配置优化工具
|
||||
├── log/ # 日志管理模块
|
||||
│ ├── __init__.py # 日志模块初始化
|
||||
│ ├── logging_utils.py # 日志工具
|
||||
│ └── audit_logger.py # 审计日志
|
||||
├── prompt/ # 提示词管理模块
|
||||
│ ├── __init__.py # 提示词模块初始化
|
||||
│ ├── prompt_utils.py # 提示词渲染工具
|
||||
│ ├── template_render.py # 模板渲染工具
|
||||
│ └── prompts/ # Jinja2 提示词模板目录
|
||||
│ ├── entity_dedup.jinja2 # 实体去重提示词
|
||||
│ ├── extract_statement.jinja2 # 陈述句提取提示词
|
||||
│ ├── extract_temporal.jinja2 # 时间信息提取提示词
|
||||
│ ├── extract_triplet.jinja2 # 三元组提取提示词
|
||||
│ ├── memory_summary.jinja2 # 记忆摘要提示词
|
||||
│ ├── evaluate.jinja2 # 评估提示词
|
||||
│ ├── reflexion.jinja2 # 反思提示词
|
||||
│ ├── system.jinja2 # 系统提示词
|
||||
│ └── user.jinja2 # 用户提示词
|
||||
├── llm/ # LLM 工具模块
|
||||
│ ├── __init__.py # LLM 模块初始化
|
||||
│ └── llm_utils.py # LLM 客户端工具
|
||||
├── data/ # 数据处理模块
|
||||
│ ├── __init__.py # 数据模块初始化
|
||||
│ ├── text_utils.py # 文本处理工具
|
||||
│ ├── time_utils.py # 时间处理工具
|
||||
│ └── ontology.py # 本体定义(谓语、标签等)
|
||||
├── paths/ # 路径管理模块
|
||||
│ ├── __init__.py # 路径模块初始化
|
||||
│ └── output_paths.py # 输出路径管理
|
||||
├── visualization/ # 可视化模块
|
||||
│ ├── __init__.py # 可视化模块初始化
|
||||
│ └── forgetting_visualizer.py # 遗忘曲线可视化
|
||||
└── self_reflexion_utils/ # 自我反思工具模块
|
||||
├── __init__.py # 反思模块初始化
|
||||
├── evaluate.py # 冲突评估
|
||||
├── reflexion.py # 反思处理
|
||||
└── self_reflexion.py # 自我反思主逻辑
|
||||
```
|
||||
|
||||
## 模块分类
|
||||
|
||||
### 1. 配置管理(config/)
|
||||
|
||||
配置管理模块包含所有与配置相关的工具函数和定义。
|
||||
|
||||
#### config_utils.py
|
||||
提供配置加载和管理功能:
|
||||
- `get_model_config(model_id)` - 获取 LLM 模型配置
|
||||
- `get_embedder_config(embedding_id)` - 获取嵌入模型配置
|
||||
- `get_neo4j_config()` - 获取 Neo4j 数据库配置
|
||||
- `get_chunker_config(chunker_strategy)` - 获取分块策略配置
|
||||
- `get_pipeline_config()` - 获取流水线配置
|
||||
- `get_pruning_config()` - 获取语义剪枝配置
|
||||
- `get_picture_config()` - 获取图片模型配置
|
||||
- `get_voice_config()` - 获取语音模型配置
|
||||
|
||||
#### definitions.py
|
||||
全局定义和常量:
|
||||
- `CONFIG` - 基础配置(从 config.json 加载)
|
||||
- `RUNTIME_CONFIG` - 运行时配置(从 runtime.json 或数据库加载)
|
||||
- `PROJECT_ROOT` - 项目根目录路径
|
||||
- 各种选择配置常量(LLM、嵌入模型、分块策略等)
|
||||
- `reload_configuration_from_database(config_id)` - 动态重新加载配置
|
||||
|
||||
#### overrides.py
|
||||
运行时配置覆写:
|
||||
- `load_unified_config(project_root)` - 加载统一配置
|
||||
|
||||
#### get_data.py
|
||||
数据获取工具:
|
||||
- `get_data(host_id)` - 从 SQL 数据库获取数据
|
||||
|
||||
#### litellm_config.py
|
||||
LiteLLM 配置和监控:
|
||||
- `LiteLLMConfig` - LiteLLM 配置类
|
||||
- `setup_litellm_enhanced(max_retries)` - 设置增强的 LiteLLM 配置
|
||||
- `get_usage_summary()` - 获取使用统计摘要
|
||||
- `print_usage_summary()` - 打印使用统计
|
||||
- `get_instant_qps(module)` - 获取即时 QPS 数据
|
||||
- `print_instant_qps(module)` - 打印即时 QPS 信息
|
||||
|
||||
#### config_optimization.py
|
||||
配置优化工具:
|
||||
- 配置参数优化相关功能
|
||||
|
||||
### 3. LLM 工具(llm/)
|
||||
|
||||
LLM 工具模块包含所有与 LLM 客户端相关的工具函数。
|
||||
|
||||
#### llm_utils.py
|
||||
LLM 客户端工具:
|
||||
- `get_llm_client(llm_id)` - 获取 LLM 客户端实例
|
||||
- `get_reranker_client(rerank_id)` - 获取重排序客户端实例
|
||||
- `handle_response(response)` - 处理 LLM 响应
|
||||
|
||||
#### litellm_config.py
|
||||
LiteLLM 配置和监控:
|
||||
- `LiteLLMConfig` - LiteLLM 配置类
|
||||
- `setup_litellm_enhanced(max_retries)` - 设置增强的 LiteLLM 配置
|
||||
- `get_usage_summary()` - 获取使用统计摘要
|
||||
- `print_usage_summary()` - 打印使用统计
|
||||
- `get_instant_qps(module)` - 获取即时 QPS 数据
|
||||
- `print_instant_qps(module)` - 打印即时 QPS 信息
|
||||
|
||||
### 4. 提示词管理(prompt/)
|
||||
|
||||
提示词管理模块包含所有提示词渲染和模板管理相关的工具函数。
|
||||
|
||||
#### prompt_utils.py
|
||||
提示词渲染工具(使用 Jinja2 模板):
|
||||
- `get_prompts(message)` - 获取系统和用户提示词
|
||||
- `render_statement_extraction_prompt(...)` - 渲染陈述句提取提示词
|
||||
- `render_temporal_extraction_prompt(...)` - 渲染时间信息提取提示词
|
||||
- `render_entity_dedup_prompt(...)` - 渲染实体去重提示词
|
||||
- `render_triplet_extraction_prompt(...)` - 渲染三元组提取提示词
|
||||
- `render_memory_summary_prompt(...)` - 渲染记忆摘要提示词
|
||||
- `prompt_env` - Jinja2 环境对象
|
||||
|
||||
#### template_render.py
|
||||
模板渲染工具(用于评估和反思):
|
||||
- `render_evaluate_prompt(evaluate_data, schema)` - 渲染评估提示词
|
||||
- `render_reflexion_prompt(data, schema)` - 渲染反思提示词
|
||||
|
||||
#### prompts/
|
||||
Jinja2 模板文件目录,包含所有提示词模板
|
||||
|
||||
### 5. 数据处理(data/)
|
||||
|
||||
数据处理模块包含所有数据处理相关的工具函数。
|
||||
|
||||
#### text_utils.py
|
||||
文本处理工具:
|
||||
- `escape_lucene_query(query)` - 转义 Lucene 查询特殊字符
|
||||
- `extract_plain_query(query_input)` - 从各种输入格式提取纯文本查询
|
||||
|
||||
#### time_utils.py
|
||||
时间处理工具:
|
||||
- `validate_date_format(date_str)` - 验证日期格式(YYYY-MM-DD)
|
||||
- `normalize_date(date_str)` - 标准化日期格式
|
||||
- `normalize_date_safe(date_str, default)` - 安全的日期标准化(带默认值)
|
||||
- `preprocess_date_string(date_str)` - 预处理日期字符串
|
||||
|
||||
#### ontology.py
|
||||
本体定义:
|
||||
- `PREDICATE_DEFINITIONS` - 谓语定义字典
|
||||
- `LABEL_DEFINITIONS` - 标签定义字典
|
||||
- `Predicate` - 谓语枚举
|
||||
- `StatementType` - 陈述句类型枚举
|
||||
- `TemporalInfo` - 时间信息枚举
|
||||
- `RelevenceInfo` - 相关性信息枚举
|
||||
|
||||
### 2. 日志管理(log/)
|
||||
|
||||
日志管理模块包含所有与日志记录相关的工具函数。
|
||||
|
||||
#### logging_utils.py
|
||||
日志工具:
|
||||
- `log_prompt_rendering(role, content)` - 记录提示词渲染
|
||||
- `log_template_rendering(template_name, context)` - 记录模板渲染
|
||||
- `log_time(operation, duration)` - 记录操作耗时
|
||||
- `prompt_logger` - 提示词日志记录器
|
||||
|
||||
#### audit_logger.py
|
||||
审计日志:
|
||||
- `audit_logger` - 审计日志记录器
|
||||
- 记录系统关键操作和安全事件
|
||||
|
||||
### 6. 自我反思工具(self_reflexion_utils/)
|
||||
|
||||
自我反思工具模块包含记忆冲突检测和反思处理功能。
|
||||
|
||||
#### evaluate.py
|
||||
冲突评估:
|
||||
- `conflict(evaluate_data, schema)` - 评估记忆冲突
|
||||
|
||||
#### reflexion.py
|
||||
反思处理:
|
||||
- `reflexion(data, schema)` - 执行反思处理
|
||||
|
||||
#### self_reflexion.py
|
||||
自我反思主逻辑:
|
||||
- `self_reflexion(...)` - 自我反思主函数
|
||||
|
||||
### 7. 数据模型
|
||||
|
||||
#### json_schema.py
|
||||
JSON Schema 数据模型:
|
||||
- `BaseDataSchema` - 基础数据模型
|
||||
- `ConflictResultSchema` - 冲突结果模型
|
||||
- `ConflictSchema` - 冲突模型
|
||||
- `ReflexionSchema` - 反思模型
|
||||
- `ResolvedSchema` - 解决方案模型
|
||||
- `ReflexionResultSchema` - 反思结果模型
|
||||
|
||||
#### messages.py
|
||||
API 消息模型:
|
||||
- `ConfigKey` - 配置键模型
|
||||
- `ChunkerStrategy` - 分块策略枚举
|
||||
- `ConfigParams` - 配置参数模型
|
||||
- `ConfigParamsCreate` - 创建配置参数模型
|
||||
- `ConfigUpdate` - 更新配置模型
|
||||
- `ConfigUpdateExtracted` - 更新萃取引擎配置模型
|
||||
- `ConfigUpdateForget` - 更新遗忘引擎配置模型
|
||||
- `ConfigPilotRun` - 试运行配置模型
|
||||
- `ConfigFilter` - 配置过滤模型
|
||||
- `ApiResponse` - API 响应模型
|
||||
- `ok(msg, data)` - 成功响应构造函数
|
||||
- `fail(msg, error_code, data)` - 失败响应构造函数
|
||||
|
||||
### 8. 可视化(visualization/)
|
||||
|
||||
可视化模块包含所有可视化相关的工具函数。
|
||||
|
||||
#### forgetting_visualizer.py
|
||||
遗忘曲线可视化:
|
||||
- `export_memory_curve_numpy(...)` - 导出记忆曲线为 NumPy 数组
|
||||
- `export_memory_curves_multiple_strengths(...)` - 导出多个强度的记忆曲线
|
||||
- `export_parameter_sweep_numpy(...)` - 导出参数扫描结果
|
||||
- `visualize_forgetting_curve(...)` - 可视化遗忘曲线
|
||||
- `plot_3d_forgetting_surface(...)` - 绘制 3D 遗忘曲线表面
|
||||
- `create_comparison_visualization(...)` - 创建对比可视化
|
||||
- `save_memory_curves_to_file(...)` - 保存记忆曲线到文件
|
||||
|
||||
### 9. 路径管理(paths/)
|
||||
|
||||
路径管理模块包含所有路径管理相关的工具函数。
|
||||
|
||||
#### output_paths.py
|
||||
输出路径管理:
|
||||
- `get_output_dir()` - 获取输出目录
|
||||
- `get_output_path(filename)` - 获取输出文件路径
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 配置管理
|
||||
|
||||
```python
|
||||
from app.core.memory.utils.config import get_model_config, get_pipeline_config
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
|
||||
|
||||
# 获取模型配置
|
||||
model_config = get_model_config("model_id_123")
|
||||
|
||||
# 获取流水线配置
|
||||
pipeline_config = get_pipeline_config()
|
||||
|
||||
# 使用全局常量
|
||||
llm_id = SELECTED_LLM_ID
|
||||
```
|
||||
|
||||
### 日志管理
|
||||
|
||||
```python
|
||||
from app.core.memory.utils.log import log_prompt_rendering, log_time, audit_logger
|
||||
|
||||
# 记录提示词渲染
|
||||
log_prompt_rendering('user', 'Hello, world!')
|
||||
|
||||
# 记录操作耗时
|
||||
log_time('extraction', 1.23)
|
||||
|
||||
# 使用审计日志
|
||||
audit_logger.info('User action performed')
|
||||
```
|
||||
|
||||
### LLM 工具
|
||||
|
||||
```python
|
||||
from app.core.memory.utils.llm import get_llm_client
|
||||
|
||||
# 获取 LLM 客户端
|
||||
llm_client = get_llm_client("llm_id_456")
|
||||
|
||||
# 调用 LLM
|
||||
response = await llm_client.chat([
|
||||
{"role": "user", "content": "Hello"}
|
||||
])
|
||||
```
|
||||
|
||||
### 提示词渲染
|
||||
|
||||
```python
|
||||
from app.core.memory.utils.prompt import render_statement_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS
|
||||
|
||||
# 渲染陈述句提取提示词
|
||||
prompt = await render_statement_extraction_prompt(
|
||||
chunk_content="对话内容...",
|
||||
definitions=LABEL_DEFINITIONS,
|
||||
json_schema=schema,
|
||||
granularity=2
|
||||
)
|
||||
```
|
||||
|
||||
### 数据处理
|
||||
|
||||
```python
|
||||
from app.core.memory.utils.data.time_utils import normalize_date
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
# 标准化日期
|
||||
normalized = normalize_date("2025/10/28") # 返回 "2025-10-28"
|
||||
|
||||
# 转义 Lucene 查询
|
||||
escaped = escape_lucene_query("user:admin AND status:active")
|
||||
```
|
||||
|
||||
### 运行时配置覆写
|
||||
|
||||
```python
|
||||
from app.core.memory.utils import apply_runtime_overrides_with_config_id
|
||||
|
||||
# 使用指定 config_id 覆写配置
|
||||
runtime_cfg = {"selections": {}}
|
||||
updated_cfg = apply_runtime_overrides_with_config_id(
|
||||
project_root="/path/to/project",
|
||||
runtime_cfg=runtime_cfg,
|
||||
config_id="config_123"
|
||||
)
|
||||
```
|
||||
|
||||
## 迁移说明
|
||||
|
||||
### 从旧路径迁移
|
||||
|
||||
如果你的代码使用了旧的导入路径,请按以下方式更新:
|
||||
|
||||
**旧路径(2024年11月之前):**
|
||||
```python
|
||||
from app.core.memory.src.utils.config_utils import get_model_config
|
||||
from app.core.memory.src.utils.prompt_utils import render_statement_extraction_prompt
|
||||
from app.core.memory.src.data_config_api.utils.messages import ok, fail
|
||||
```
|
||||
|
||||
**中间路径(2024年11月):**
|
||||
```python
|
||||
from app.core.memory.utils.config_utils import get_model_config
|
||||
from app.core.memory.utils.logging_utils import log_prompt_rendering
|
||||
from app.schemas.memory_storage_schema import ok, fail
|
||||
```
|
||||
|
||||
**新路径(2024年11月27日之后):**
|
||||
```python
|
||||
# 配置相关
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.memory.utils.config import get_model_config # 简化导入
|
||||
|
||||
# 日志相关
|
||||
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
|
||||
from app.core.memory.utils.log import log_prompt_rendering # 简化导入
|
||||
|
||||
# 其他工具
|
||||
from app.core.memory.utils import prompt_utils
|
||||
from app.schemas.memory_storage_schema import ok, fail
|
||||
```
|
||||
|
||||
### 目录结构重组(2024年11月27日)
|
||||
|
||||
utils 目录已按功能进行了完整的重组:
|
||||
|
||||
**重组前的结构:**
|
||||
- 所有文件都在 `app/core/memory/utils/` 根目录下
|
||||
|
||||
**重组后的结构:**
|
||||
- `config/` - 配置管理相关文件
|
||||
- `log/` - 日志管理相关文件
|
||||
- `prompt/` - 提示词管理相关文件
|
||||
- `llm/` - LLM 工具相关文件
|
||||
- `data/` - 数据处理相关文件
|
||||
- `paths/` - 路径管理相关文件
|
||||
- `visualization/` - 可视化相关文件
|
||||
- `self_reflexion_utils/` - 自我反思工具(已存在)
|
||||
|
||||
**导入路径变化:**
|
||||
```python
|
||||
# 旧导入方式
|
||||
from app.core.memory.utils.config_utils import get_model_config
|
||||
from app.core.memory.utils.logging_utils import log_prompt_rendering
|
||||
from app.core.memory.utils.prompt_utils import render_statement_extraction_prompt
|
||||
|
||||
# 新导入方式
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||
|
||||
# 或使用简化导入
|
||||
from app.core.memory.utils.config import get_model_config
|
||||
from app.core.memory.utils.log import log_prompt_rendering
|
||||
from app.core.memory.utils.prompt import render_statement_extraction_prompt
|
||||
```
|
||||
|
||||
## 维护指南
|
||||
|
||||
### 添加新工具函数
|
||||
|
||||
1. 在相应的模块文件中添加函数
|
||||
2. 在 `__init__.py` 中导出函数
|
||||
3. 在本 README 中添加文档
|
||||
4. 编写单元测试
|
||||
|
||||
### 删除旧工具函数
|
||||
|
||||
1. 确认没有代码使用该函数
|
||||
2. 从模块文件中删除函数
|
||||
3. 从 `__init__.py` 中删除导出
|
||||
4. 更新本 README
|
||||
|
||||
### 重构工具函数
|
||||
|
||||
1. 保持向后兼容性(使用别名或包装器)
|
||||
2. 更新所有使用该函数的代码
|
||||
3. 更新文档和测试
|
||||
4. 在适当时机删除旧版本
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **向后兼容性**:所有工具函数应保持向后兼容,避免破坏现有代码
|
||||
2. **文档完整性**:每个函数都应有清晰的文档字符串
|
||||
3. **类型注解**:使用类型注解提高代码可读性
|
||||
4. **错误处理**:工具函数应有适当的错误处理
|
||||
5. **测试覆盖**:所有工具函数都应有单元测试
|
||||
|
||||
## 相关文档
|
||||
|
||||
- [Memory 模块架构设计](../.kiro/specs/memory-refactoring/design.md)
|
||||
- [Memory 模块需求文档](../.kiro/specs/memory-refactoring/requirements.md)
|
||||
- [Memory 模块任务列表](../.kiro/specs/memory-refactoring/tasks.md)
|
||||
@@ -6,33 +6,26 @@
|
||||
|
||||
# 从子模块导出常用函数和常量,保持向后兼容
|
||||
from .config_utils import (
|
||||
get_model_config,
|
||||
get_embedder_config,
|
||||
get_neo4j_config,
|
||||
get_chunker_config,
|
||||
get_embedder_config,
|
||||
get_model_config,
|
||||
get_picture_config,
|
||||
get_pipeline_config,
|
||||
get_pruning_config,
|
||||
get_picture_config,
|
||||
get_voice_config,
|
||||
)
|
||||
from .definitions import (
|
||||
CONFIG,
|
||||
RUNTIME_CONFIG,
|
||||
PROJECT_ROOT,
|
||||
SELECTED_LLM_ID,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_RERANK_ID,
|
||||
SELECTED_LLM_PICTURE_NAME,
|
||||
SELECTED_LLM_VOICE_NAME,
|
||||
REFLEXION_ENABLED,
|
||||
REFLEXION_ITERATION_PERIOD,
|
||||
REFLEXION_RANGE,
|
||||
REFLEXION_BASELINE,
|
||||
reload_configuration_from_database,
|
||||
)
|
||||
from .overrides import load_unified_config
|
||||
|
||||
# DEPRECATED: Global configuration variables removed
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
# from .definitions import (
|
||||
# CONFIG, # DEPRECATED - empty dict for backward compatibility
|
||||
# RUNTIME_CONFIG, # DEPRECATED - minimal for backward compatibility
|
||||
# PROJECT_ROOT, # Still needed for file paths
|
||||
# reload_configuration_from_database, # DEPRECATED - returns False
|
||||
# )
|
||||
# DEPRECATED: overrides module removed - use MemoryConfig with dependency injection
|
||||
from .get_data import get_data
|
||||
|
||||
# litellm_config 需要时动态导入,避免循环依赖
|
||||
# from .litellm_config import (
|
||||
# LiteLLMConfig,
|
||||
@@ -47,29 +40,16 @@ __all__ = [
|
||||
# config_utils
|
||||
"get_model_config",
|
||||
"get_embedder_config",
|
||||
"get_neo4j_config",
|
||||
"get_chunker_config",
|
||||
"get_pipeline_config",
|
||||
"get_pruning_config",
|
||||
"get_picture_config",
|
||||
"get_voice_config",
|
||||
# definitions
|
||||
"CONFIG",
|
||||
"RUNTIME_CONFIG",
|
||||
"PROJECT_ROOT",
|
||||
"SELECTED_LLM_ID",
|
||||
"SELECTED_EMBEDDING_ID",
|
||||
"SELECTED_GROUP_ID",
|
||||
"SELECTED_RERANK_ID",
|
||||
"SELECTED_LLM_PICTURE_NAME",
|
||||
"SELECTED_LLM_VOICE_NAME",
|
||||
"REFLEXION_ENABLED",
|
||||
"REFLEXION_ITERATION_PERIOD",
|
||||
"REFLEXION_RANGE",
|
||||
"REFLEXION_BASELINE",
|
||||
"reload_configuration_from_database",
|
||||
# overrides
|
||||
"load_unified_config",
|
||||
# definitions (DEPRECATED - use MemoryConfig objects instead)
|
||||
# "CONFIG", # DEPRECATED
|
||||
# "RUNTIME_CONFIG", # DEPRECATED
|
||||
# "PROJECT_ROOT",
|
||||
# "reload_configuration_from_database", # DEPRECATED
|
||||
# get_data
|
||||
"get_data",
|
||||
# litellm_config - 需要时从 .litellm_config 直接导入
|
||||
|
||||
@@ -1,94 +1,74 @@
|
||||
import uuid
|
||||
import json
|
||||
from typing import Optional
|
||||
"""
|
||||
Configuration utilities - Backward compatibility layer
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi import status
|
||||
DEPRECATED: These functions now require a db session parameter.
|
||||
New code should use MemoryConfigService(db) instance directly.
|
||||
|
||||
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
|
||||
from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
DedupConfig,
|
||||
StatementExtractionConfig,
|
||||
ForgettingEngineConfig,
|
||||
)
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.db import get_db
|
||||
from app.models.models_model import ModelConfig, ModelApiKey
|
||||
from app.services.model_service import ModelConfigService
|
||||
def get_model_config(model_id: str, db: Session | None = None) -> dict:
|
||||
For functions that don't require db (get_pipeline_config, get_pruning_config),
|
||||
they are still re-exported here.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# These functions don't require db - safe to re-export as static methods
|
||||
get_pipeline_config = MemoryConfigService.get_pipeline_config
|
||||
get_pruning_config = MemoryConfigService.get_pruning_config
|
||||
|
||||
|
||||
def get_model_config(model_id: str, db=None):
|
||||
"""DEPRECATED: Use MemoryConfigService(db).get_model_config(model_id) directly."""
|
||||
if db is None:
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen) # 取到真正的 Session
|
||||
raise ValueError(
|
||||
"get_model_config now requires a db session. "
|
||||
"Use MemoryConfigService(db).get_model_config(model_id) directly."
|
||||
)
|
||||
return MemoryConfigService(db).get_model_config(model_id)
|
||||
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
if not config:
|
||||
print(f"模型ID {model_id} 不存在")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
|
||||
# 从环境变量读取超时和重试配置
|
||||
from app.core.config import settings
|
||||
|
||||
model_config = {
|
||||
"model_name": apiConfig.model_name,
|
||||
"provider": apiConfig.provider,
|
||||
"api_key": apiConfig.api_key,
|
||||
"base_url": apiConfig.api_base,
|
||||
"model_config_id":apiConfig.model_config_id,
|
||||
"type": config.type,
|
||||
# 添加超时和重试配置,避免 LLM 请求超时
|
||||
"timeout": settings.LLM_TIMEOUT, # 从环境变量读取,默认120秒
|
||||
"max_retries": settings.LLM_MAX_RETRIES, # 从环境变量读取,默认2次
|
||||
}
|
||||
# 写入model_config.log文件中
|
||||
with open("logs/model_config.log", "a", encoding="utf-8") as f:
|
||||
f.write(f"模型ID: {model_id}\n")
|
||||
f.write(f"模型配置信息:\n{model_config}\n")
|
||||
f.write("=============================\n\n")
|
||||
return model_config
|
||||
|
||||
def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict:
|
||||
def get_embedder_config(embedding_id: str, db=None):
|
||||
"""DEPRECATED: Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."""
|
||||
if db is None:
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen) # 取到真正的 Session
|
||||
raise ValueError(
|
||||
"get_embedder_config now requires a db session. "
|
||||
"Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."
|
||||
)
|
||||
return MemoryConfigService(db).get_embedder_config(embedding_id)
|
||||
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=embedding_id)
|
||||
if not config:
|
||||
print(f"嵌入模型ID {embedding_id} 不存在")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
model_config = {
|
||||
"model_name": apiConfig.model_name,
|
||||
"provider": apiConfig.provider,
|
||||
"api_key": apiConfig.api_key,
|
||||
"base_url": apiConfig.api_base,
|
||||
"model_config_id":apiConfig.model_config_id,
|
||||
# Ensure required field for RedBearModelConfig validation
|
||||
"type": config.type,
|
||||
# 添加超时和重试配置,避免嵌入服务请求超时
|
||||
"timeout": 120.0, # 嵌入服务超时时间(秒)
|
||||
"max_retries": 5, # 最大重试次数
|
||||
}
|
||||
# 写入embedder_config.log文件中
|
||||
with open("logs/embedder_config.log", "a", encoding="utf-8") as f:
|
||||
f.write(f"嵌入模型ID: {embedding_id}\n")
|
||||
f.write(f"嵌入模型配置信息:\n{model_config}\n")
|
||||
f.write("=============================\n\n")
|
||||
return model_config
|
||||
|
||||
def get_neo4j_config() -> dict:
|
||||
"""Retrieves the Neo4j configuration from the config file."""
|
||||
return CONFIG.get("neo4j", {})
|
||||
def get_picture_config(llm_name: str) -> dict:
|
||||
"""Retrieves the configuration for a specific model from the config file."""
|
||||
"""Retrieves the configuration for a specific model from the config file.
|
||||
|
||||
.. deprecated::
|
||||
This function is deprecated and will be removed in a future version.
|
||||
Use database-backed model configuration instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"get_picture_config is deprecated and will be removed in a future version. "
|
||||
"Use database-backed model configuration instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
for model_config in CONFIG.get("picture_recognition", []):
|
||||
if model_config["llm_name"] == llm_name:
|
||||
return model_config
|
||||
raise ValueError(f"Model '{llm_name}' not found in config.json")
|
||||
|
||||
|
||||
def get_voice_config(llm_name: str) -> dict:
|
||||
"""Retrieves the configuration for a specific model from the config file."""
|
||||
"""Retrieves the configuration for a specific model from the config file.
|
||||
|
||||
.. deprecated::
|
||||
This function is deprecated and will be removed in a future version.
|
||||
Use database-backed model configuration instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"get_voice_config is deprecated and will be removed in a future version. "
|
||||
"Use database-backed model configuration instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
for model_config in CONFIG.get("voice_recognition", []):
|
||||
if model_config["llm_name"] == llm_name:
|
||||
return model_config
|
||||
@@ -96,20 +76,15 @@ def get_voice_config(llm_name: str) -> dict:
|
||||
|
||||
|
||||
def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
"""Retrieves the configuration for a specific chunker strategy.
|
||||
"""Retrieves the configuration for a specific chunker strategy."""
|
||||
|
||||
Enhancements:
|
||||
- Supports default configs for `LLMChunker` and `HybridChunker` if not present.
|
||||
- Falls back to the first available chunker config when the requested one is missing.
|
||||
"""
|
||||
# 1) Try to find exact match in config
|
||||
chunker_list = CONFIG.get("chunker_list", [])
|
||||
for chunker_config in chunker_list:
|
||||
if chunker_config.get("chunker_strategy") == chunker_strategy:
|
||||
return chunker_config
|
||||
|
||||
# 2) Provide sane defaults for newer strategies
|
||||
default_configs = {
|
||||
"RecursiveChunker": {
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"min_characters_per_chunk": 50
|
||||
},
|
||||
"LLMChunker": {
|
||||
"chunker_strategy": "LLMChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
@@ -134,134 +109,6 @@ def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
if chunker_strategy in default_configs:
|
||||
return default_configs[chunker_strategy]
|
||||
|
||||
# 3) Fallback: use first available config but tag with requested strategy
|
||||
if chunker_list:
|
||||
fallback = chunker_list[0].copy()
|
||||
fallback["chunker_strategy"] = chunker_strategy
|
||||
# Non-fatal notice for visibility in logs if any
|
||||
print(f"Warning: Using first available chunker config as fallback for '{chunker_strategy}'")
|
||||
return fallback
|
||||
|
||||
# 4) If no configs available at all
|
||||
raise ValueError(
|
||||
f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available"
|
||||
f"Chunker '{chunker_strategy}' not found "
|
||||
)
|
||||
|
||||
|
||||
def get_pipeline_config() -> ExtractionPipelineConfig:
|
||||
"""Build ExtractionPipelineConfig using only runtime.json values.
|
||||
|
||||
Behavior:
|
||||
- Read `deduplication` section from runtime.json if present.
|
||||
- Read `statement_extraction` section from runtime.json if present.
|
||||
- Read `forgetting_engine` section from runtime.json if present.
|
||||
- If absent, check legacy top-level `enable_llm_dedup` key.
|
||||
- Do NOT fall back to environment variables.
|
||||
- Unspecified fields use model defaults defined in DedupConfig.
|
||||
"""
|
||||
dedup_rc = RUNTIME_CONFIG.get("deduplication", {}) or {}
|
||||
stmt_rc = RUNTIME_CONFIG.get("statement_extraction", {}) or {}
|
||||
forget_rc = RUNTIME_CONFIG.get("forgetting_engine", {}) or {}
|
||||
|
||||
# Assemble kwargs from runtime.json only
|
||||
kwargs = {}
|
||||
# LLM switch: prefer new key, then legacy top-level, default False
|
||||
if "enable_llm_dedup_blockwise" in dedup_rc:
|
||||
kwargs["enable_llm_dedup_blockwise"] = bool(dedup_rc.get("enable_llm_dedup_blockwise"))
|
||||
else:
|
||||
# Legacy top-level fallback inside runtime.json only
|
||||
legacy = RUNTIME_CONFIG.get("enable_llm_dedup")
|
||||
if legacy is not None:
|
||||
kwargs["enable_llm_dedup_blockwise"] = bool(legacy)
|
||||
else:
|
||||
kwargs["enable_llm_dedup_blockwise"] = False # default reserve
|
||||
# Disambiguation switch: only from runtime.json deduplication section
|
||||
if "enable_llm_disambiguation" in dedup_rc:
|
||||
kwargs["enable_llm_disambiguation"] = bool(dedup_rc.get("enable_llm_disambiguation"))
|
||||
|
||||
# Optional LLM fallback gating
|
||||
if "enable_llm_fallback_only_on_borderline" in dedup_rc:
|
||||
kwargs["enable_llm_fallback_only_on_borderline"] = bool(dedup_rc.get("enable_llm_fallback_only_on_borderline"))
|
||||
|
||||
# Optional fuzzy thresholds: use values if provided; otherwise rely on DedupConfig defaults
|
||||
for key in (
|
||||
"fuzzy_name_threshold_strict",
|
||||
"fuzzy_type_threshold_strict",
|
||||
"fuzzy_overall_threshold",
|
||||
"fuzzy_unknown_type_name_threshold",
|
||||
"fuzzy_unknown_type_type_threshold",
|
||||
):
|
||||
if key in dedup_rc:
|
||||
kwargs[key] = dedup_rc[key]
|
||||
|
||||
# Optional weights and bonuses for overall scoring
|
||||
for key in (
|
||||
"name_weight",
|
||||
"desc_weight",
|
||||
"type_weight",
|
||||
"context_bonus",
|
||||
"llm_fallback_floor",
|
||||
"llm_fallback_ceiling",
|
||||
):
|
||||
if key in dedup_rc:
|
||||
kwargs[key] = dedup_rc[key]
|
||||
|
||||
# Optional LLM iterative dedup parameters
|
||||
for key in (
|
||||
"llm_block_size",
|
||||
"llm_block_concurrency",
|
||||
"llm_pair_concurrency",
|
||||
"llm_max_rounds",
|
||||
):
|
||||
if key in dedup_rc:
|
||||
kwargs[key] = dedup_rc[key]
|
||||
|
||||
dedup_config = DedupConfig(**kwargs)
|
||||
|
||||
# Build StatementExtractionConfig from runtime.json
|
||||
stmt_kwargs = {}
|
||||
for key in (
|
||||
"statement_granularity",
|
||||
"temperature",
|
||||
"include_dialogue_context",
|
||||
"max_dialogue_context_chars",
|
||||
):
|
||||
if key in stmt_rc:
|
||||
stmt_kwargs[key] = stmt_rc[key]
|
||||
stmt_config = StatementExtractionConfig(**stmt_kwargs)
|
||||
|
||||
# Build ForgettingEngineConfig from runtime.json
|
||||
forget_kwargs = {}
|
||||
for key in ("offset", "lambda_time", "lambda_mem"):
|
||||
if key in forget_rc:
|
||||
forget_kwargs[key] = forget_rc[key]
|
||||
forget_config = ForgettingEngineConfig(**forget_kwargs)
|
||||
|
||||
return ExtractionPipelineConfig(
|
||||
statement_extraction=stmt_config,
|
||||
deduplication=dedup_config,
|
||||
forgetting_engine=forget_config,
|
||||
)
|
||||
|
||||
|
||||
def get_pruning_config() -> dict:
|
||||
"""Retrieve semantic pruning config from runtime.json.
|
||||
|
||||
Returns a dict suitable for PruningConfig.model_validate.
|
||||
|
||||
Structure in runtime.json:
|
||||
{
|
||||
"pruning": {
|
||||
"enabled": true,
|
||||
"scene": "education" | "online_service" | "outbound",
|
||||
"threshold": 0.5
|
||||
}
|
||||
}
|
||||
"""
|
||||
pruning_rc = RUNTIME_CONFIG.get("pruning", {}) or {}
|
||||
|
||||
return {
|
||||
"pruning_switch": bool(pruning_rc.get("enabled", False)),
|
||||
"pruning_scene": pruning_rc.get("scene", "education"),
|
||||
"pruning_threshold": float(pruning_rc.get("threshold", 0.5)),
|
||||
}
|
||||
|
||||
@@ -1,360 +1,268 @@
|
||||
"""
|
||||
配置加载模块 - 三阶段架构(已迁移到统一配置管理)
|
||||
# """
|
||||
# 配置加载模块 - DEPRECATED
|
||||
|
||||
本模块现在使用全局配置管理系统 (app/core/config.py)
|
||||
来加载和管理配置,同时保持向后兼容性。
|
||||
# ⚠️ DEPRECATION NOTICE ⚠️
|
||||
# This module is deprecated and will be removed in a future version.
|
||||
# Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)
|
||||
阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
from typing import Any, Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
# Use the new MemoryConfig system instead:
|
||||
# - app.schemas.memory_config_schema.MemoryConfig for configuration objects
|
||||
# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id)
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except Exception:
|
||||
pass
|
||||
# 阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED
|
||||
# 阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED
|
||||
# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
# """
|
||||
# import json
|
||||
# import os
|
||||
# import threading
|
||||
# from datetime import datetime, timedelta
|
||||
# from typing import Any, Dict, Optional
|
||||
|
||||
# Import unified configuration system
|
||||
try:
|
||||
from app.core.config import settings
|
||||
USE_UNIFIED_CONFIG = True
|
||||
except ImportError:
|
||||
USE_UNIFIED_CONFIG = False
|
||||
settings = None
|
||||
# #TODO: Fix this
|
||||
|
||||
# PROJECT_ROOT 应该指向 app/core/memory/ 目录
|
||||
# __file__ = app/core/memory/utils/config/definitions.py
|
||||
# os.path.dirname(__file__) = app/core/memory/utils/config
|
||||
# os.path.dirname(...) = app/core/memory/utils
|
||||
# os.path.dirname(...) = app/core/memory
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# try:
|
||||
# from dotenv import load_dotenv
|
||||
# load_dotenv()
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
# 全局配置锁 - 用于线程安全
|
||||
_config_lock = threading.RLock()
|
||||
# # Import unified configuration system
|
||||
# try:
|
||||
# from app.core.config import settings
|
||||
# USE_UNIFIED_CONFIG = True
|
||||
# except ImportError:
|
||||
# USE_UNIFIED_CONFIG = False
|
||||
# settings = None
|
||||
|
||||
# 加载基础配置(config.json)- 使用全局配置系统
|
||||
if USE_UNIFIED_CONFIG:
|
||||
CONFIG = settings.load_memory_config()
|
||||
else:
|
||||
# Fallback to legacy loading
|
||||
config_path = os.path.join(PROJECT_ROOT, "config.json")
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
CONFIG = json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
print("Warning: config.json not found or is malformed. Using default settings.")
|
||||
CONFIG = {}
|
||||
# # PROJECT_ROOT 应该指向 app/core/memory/ 目录
|
||||
# # __file__ = app/core/memory/utils/config/definitions.py
|
||||
# # os.path.dirname(__file__) = app/core/memory/utils/config
|
||||
# # os.path.dirname(...) = app/core/memory/utils
|
||||
# # os.path.dirname(...) = app/core/memory
|
||||
# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
DEFAULT_VALUES = {
|
||||
"llm_name": "openai/qwen-plus",
|
||||
"embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
"group_id": "group_123",
|
||||
"user_id": "default_user",
|
||||
"apply_id": "default_apply",
|
||||
"llm_agent_name": "openai/qwen-plus",
|
||||
"llm_verify_name": "openai/qwen-plus",
|
||||
"llm_image_recognition": "openai/qwen-plus",
|
||||
"llm_voice_recognition": "openai/qwen-plus",
|
||||
"prompt_level": "DEBUG",
|
||||
"reflexion_iteration_period": "3",
|
||||
"reflexion_range": "retrieval",
|
||||
"reflexion_baseline": "TIME",
|
||||
}
|
||||
# # DEPRECATED: Global configuration lock removed
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
|
||||
# # DEPRECATED: Legacy config.json loading removed
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
# CONFIG = {}
|
||||
|
||||
# DEFAULT_VALUES = {
|
||||
# "llm_name": "openai/qwen-plus",
|
||||
# "embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
# "chunker_strategy": "RecursiveChunker",
|
||||
# "group_id": "group_123",
|
||||
# "user_id": "default_user",
|
||||
# "apply_id": "default_apply",
|
||||
# "llm_agent_name": "openai/qwen-plus",
|
||||
# "llm_verify_name": "openai/qwen-plus",
|
||||
# "llm_image_recognition": "openai/qwen-plus",
|
||||
# "llm_voice_recognition": "openai/qwen-plus",
|
||||
# "prompt_level": "DEBUG",
|
||||
# "reflexion_iteration_period": "3",
|
||||
# "reflexion_range": "retrieval",
|
||||
# "reflexion_baseline": "TIME",
|
||||
# }
|
||||
|
||||
# # DEPRECATED: Legacy global variables for backward compatibility only
|
||||
# # These will be removed in a future version
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
|
||||
# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
|
||||
|
||||
|
||||
# 阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
def _load_from_runtime_json() -> Dict[str, Any]:
|
||||
"""
|
||||
从 runtime.json 文件加载配置(通过统一配置加载器)
|
||||
# # 阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
# def _load_from_runtime_json() -> Dict[str, Any]:
|
||||
# """
|
||||
# DEPRECATED: Legacy runtime.json loading
|
||||
|
||||
使用 overrides.py 的统一配置加载器,按优先级加载:
|
||||
1. 数据库配置(如果 dbrun.json 中有 config_id/group_id)
|
||||
2. 环境变量配置
|
||||
3. runtime.json 默认配置
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 运行时配置字典
|
||||
"""
|
||||
try:
|
||||
# 使用 overrides.py 的统一配置加载器
|
||||
from app.core.memory.utils.config.overrides import load_unified_config
|
||||
# Returns:
|
||||
# Dict[str, Any]: Empty configuration (legacy support only)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return {"selections": {}}
|
||||
|
||||
|
||||
# # 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器
|
||||
# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
|
||||
# # 保留此函数仅为向后兼容
|
||||
# def _load_from_database() -> Optional[Dict[str, Any]]:
|
||||
# """
|
||||
# DEPRECATED: Legacy database configuration loading
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# Returns:
|
||||
# Optional[Dict[str, Any]]: None (deprecated functionality)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return None
|
||||
|
||||
|
||||
# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
|
||||
# """
|
||||
# DEPRECATED: 将运行时配置暴露为全局常量供项目使用
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
# Use the new MemoryConfig system instead:
|
||||
# - app.core.memory_config.config.MemoryConfig for configuration objects
|
||||
# - Pass configuration objects as parameters instead of using global variables
|
||||
|
||||
# Args:
|
||||
# runtime_cfg: 运行时配置字典
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# # Keep minimal global state for backward compatibility only
|
||||
# # These will be removed in a future version
|
||||
# global RUNTIME_CONFIG, SELECTIONS
|
||||
|
||||
# RUNTIME_CONFIG = runtime_cfg
|
||||
# SELECTIONS = RUNTIME_CONFIG.get("selections", {})
|
||||
|
||||
# # All other global variables have been removed
|
||||
# # Use MemoryConfig objects instead
|
||||
|
||||
|
||||
# # 初始化:使用统一配置加载器
|
||||
# def _initialize_configuration() -> None:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration initialization
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# # Initialize with empty configuration for backward compatibility
|
||||
# _expose_runtime_constants({"selections": {}})
|
||||
|
||||
|
||||
# # 模块加载时自动初始化配置
|
||||
# _initialize_configuration()
|
||||
|
||||
# # DEPRECATED: Global variables removed
|
||||
# # These variables have been eliminated in favor of dependency injection
|
||||
# # Use MemoryConfig objects instead of accessing global variables
|
||||
|
||||
|
||||
# # 公共 API:动态重新加载配置
|
||||
# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration reloading
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# For new code, use:
|
||||
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
# Args:
|
||||
# config_id: Configuration ID (deprecated)
|
||||
# force_reload: Force reload flag (deprecated)
|
||||
|
||||
# Returns:
|
||||
# bool: Always returns False (deprecated functionality)
|
||||
# """
|
||||
# import logging
|
||||
# import warnings
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
# warnings.warn(
|
||||
# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
|
||||
# "Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
# return False
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# def get_current_config_id() -> Optional[str]:
|
||||
# """
|
||||
# DEPRECATED: Legacy config ID retrieval
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# Returns:
|
||||
# Optional[str]: None (deprecated functionality)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return None
|
||||
|
||||
|
||||
# def ensure_fresh_config(config_id = None) -> bool:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration freshness check
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# For new code, use:
|
||||
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
# Args:
|
||||
# config_id: Configuration ID (deprecated)
|
||||
|
||||
runtime_cfg = load_unified_config(PROJECT_ROOT)
|
||||
return runtime_cfg
|
||||
except Exception as e:
|
||||
# Fallback: 直接读取 runtime.json
|
||||
runtime_config_path = os.path.join(PROJECT_ROOT, "runtime.json")
|
||||
try:
|
||||
with open(runtime_config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e2:
|
||||
pass # print(f"[definitions] ❌ 无法加载 runtime.json: {e2},使用空配置")
|
||||
return {"selections": {}}
|
||||
|
||||
|
||||
# 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器
|
||||
# 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
|
||||
# 保留此函数仅为向后兼容
|
||||
def _load_from_database() -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
从数据库加载配置(基于 dbrun.json 中的 config_id)
|
||||
# Returns:
|
||||
# bool: Always returns False (deprecated functionality)
|
||||
# """
|
||||
# import logging
|
||||
# import warnings
|
||||
|
||||
注意:此函数已被统一配置加载器替代,现在直接调用 _load_from_runtime_json
|
||||
即可获得包含数据库配置的完整配置。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 配置字典
|
||||
"""
|
||||
try:
|
||||
# 直接使用统一配置加载器
|
||||
return _load_from_runtime_json()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)
|
||||
def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
|
||||
"""
|
||||
将运行时配置暴露为全局常量供项目使用
|
||||
|
||||
这是路径 A(runtime.json)和路径 B(数据库)的汇合点,
|
||||
无论配置来自哪里,都通过这个函数统一暴露为常量。
|
||||
|
||||
Args:
|
||||
runtime_cfg: 运行时配置字典
|
||||
"""
|
||||
global RUNTIME_CONFIG, SELECTIONS, LOGGING_CONFIG
|
||||
global LANGFUSE_ENABLED, AGENTA_ENABLED, PROMPT_LOG_LEVEL_NAME
|
||||
global SELECTED_LLM_NAME, SELECTED_EMBEDDING_NAME, SELECTED_CHUNKER_STRATEGY
|
||||
global SELECTED_GROUP_ID, SELECTED_USER_ID, SELECTED_APPLY_ID, SELECTED_TEST_DATA_INDICES
|
||||
global SELECTED_LLM_AGENT_NAME, SELECTED_LLM_VERIFY_NAME, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
|
||||
global SELECTED_LLM_ID, SELECTED_EMBEDDING_ID, SELECTED_RERANK_ID
|
||||
global REFLEXION_CONFIG, REFLEXION_ENABLED, REFLEXION_ITERATION_PERIOD, REFLEXION_RANGE, REFLEXION_BASELINE
|
||||
|
||||
RUNTIME_CONFIG = runtime_cfg
|
||||
|
||||
# 可观测性配置
|
||||
LANGFUSE_ENABLED = RUNTIME_CONFIG.get("langfuse", {}).get("enabled", False)
|
||||
AGENTA_ENABLED = RUNTIME_CONFIG.get("agenta", {}).get("enabled", False)
|
||||
|
||||
# 日志配置
|
||||
LOGGING_CONFIG = RUNTIME_CONFIG.get("logging", {})
|
||||
PROMPT_LOG_LEVEL_NAME = LOGGING_CONFIG.get("prompt_level", DEFAULT_VALUES["prompt_level"])
|
||||
|
||||
# 选择配置
|
||||
SELECTIONS = RUNTIME_CONFIG.get("selections", {})
|
||||
|
||||
# 基础模型选择
|
||||
SELECTED_LLM_NAME = SELECTIONS.get("llm_name", DEFAULT_VALUES["llm_name"])
|
||||
SELECTED_EMBEDDING_NAME = SELECTIONS.get("embedding_name", DEFAULT_VALUES["embedding_name"])
|
||||
SELECTED_CHUNKER_STRATEGY = SELECTIONS.get("chunker_strategy", DEFAULT_VALUES["chunker_strategy"])
|
||||
|
||||
# 分组和用户配置
|
||||
SELECTED_GROUP_ID = SELECTIONS.get("group_id", DEFAULT_VALUES["group_id"])
|
||||
SELECTED_USER_ID = SELECTIONS.get("user_id", DEFAULT_VALUES["user_id"])
|
||||
SELECTED_APPLY_ID = SELECTIONS.get("apply_id", DEFAULT_VALUES["apply_id"])
|
||||
SELECTED_TEST_DATA_INDICES = SELECTIONS.get("test_data_indices", None)
|
||||
|
||||
# 专用 LLM 配置
|
||||
SELECTED_LLM_AGENT_NAME = SELECTIONS.get("llm_agent_name", DEFAULT_VALUES["llm_agent_name"])
|
||||
SELECTED_LLM_VERIFY_NAME = SELECTIONS.get("llm_verify_name", DEFAULT_VALUES["llm_verify_name"])
|
||||
SELECTED_LLM_PICTURE_NAME = SELECTIONS.get("llm_image_recognition", DEFAULT_VALUES["llm_image_recognition"])
|
||||
SELECTED_LLM_VOICE_NAME = SELECTIONS.get("llm_voice_recognition", DEFAULT_VALUES["llm_voice_recognition"])
|
||||
|
||||
# 模型 ID 配置
|
||||
SELECTED_LLM_ID = SELECTIONS.get("llm_id", None)
|
||||
SELECTED_EMBEDDING_ID = SELECTIONS.get("embedding_id", None)
|
||||
SELECTED_RERANK_ID = SELECTIONS.get("rerank_id", None)
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
# 反思配置
|
||||
REFLEXION_CONFIG = RUNTIME_CONFIG.get("reflexion", {})
|
||||
REFLEXION_ENABLED = REFLEXION_CONFIG.get("enabled", False)
|
||||
REFLEXION_ITERATION_PERIOD = REFLEXION_CONFIG.get("iteration_period", DEFAULT_VALUES["reflexion_iteration_period"])
|
||||
REFLEXION_RANGE = REFLEXION_CONFIG.get("reflexion_range", DEFAULT_VALUES["reflexion_range"])
|
||||
REFLEXION_BASELINE = REFLEXION_CONFIG.get("baseline", DEFAULT_VALUES["reflexion_baseline"])
|
||||
|
||||
|
||||
# 初始化:使用统一配置加载器
|
||||
def _initialize_configuration() -> None:
|
||||
"""
|
||||
初始化配置:使用统一配置加载器
|
||||
# warnings.warn(
|
||||
# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
配置加载优先级(由 overrides.py 统一处理):
|
||||
1. 数据库配置(如果 dbrun.json 中有 config_id/group_id)
|
||||
2. 环境变量配置(.env)
|
||||
3. runtime.json 默认配置
|
||||
"""
|
||||
try:
|
||||
|
||||
# 使用统一配置加载器(已包含所有优先级处理)
|
||||
runtime_config = _load_from_runtime_json()
|
||||
|
||||
# 暴露为全局常量
|
||||
_expose_runtime_constants(runtime_config)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
pass # print(f"[definitions] × 配置初始化失败: {e}")
|
||||
# 使用空配置
|
||||
_expose_runtime_constants({"selections": {}})
|
||||
|
||||
|
||||
# 模块加载时自动初始化配置
|
||||
_initialize_configuration()
|
||||
|
||||
|
||||
# 公共 API:动态重新加载配置
|
||||
def reload_configuration_from_database(config_id: int | str, force_reload: bool = False) -> bool:
|
||||
"""
|
||||
动态重新加载配置(从数据库)- 使用统一配置加载器
|
||||
用于运行时切换配置,例如前端传入新的 config_id 时调用。
|
||||
|
||||
注意:此函数仅在内存中覆写配置,不会修改 runtime.json 文件。
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID(整数或字符串,会自动转换)
|
||||
force_reload: 保留参数以保持向后兼容(已移除缓存逻辑)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功重新加载配置
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
|
||||
# "Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
# 导入审计日志记录器
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
|
||||
with _config_lock:
|
||||
try:
|
||||
from app.core.memory.utils.config.overrides import load_unified_config
|
||||
except Exception as e:
|
||||
logger.error(f"[definitions] 导入统一配置加载器失败: {e}")
|
||||
|
||||
# 记录配置加载失败
|
||||
if audit_logger:
|
||||
audit_logger.log_config_load(
|
||||
config_id=config_id,
|
||||
success=False,
|
||||
details={"error": f"Import failed: {str(e)}"}
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info(f"[definitions] 开始重新加载配置,config_id={config_id}")
|
||||
|
||||
# 使用统一配置加载器(指定 config_id)
|
||||
updated_cfg = load_unified_config(PROJECT_ROOT, config_id=config_id)
|
||||
|
||||
# 检查是否成功加载
|
||||
if not updated_cfg or not updated_cfg.get('selections'):
|
||||
logger.error(f"[definitions] 配置加载失败:数据库中未找到 config_id={config_id} 的配置")
|
||||
|
||||
# 记录配置加载失败
|
||||
if audit_logger:
|
||||
audit_logger.log_config_load(
|
||||
config_id=config_id,
|
||||
success=False,
|
||||
details={"reason": "config not found in database"}
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
# 重新暴露常量
|
||||
_expose_runtime_constants(updated_cfg)
|
||||
|
||||
logger.info("[definitions] 配置重新加载成功,已暴露常量")
|
||||
logger.debug(f"[definitions] 配置详情: LLM_ID={updated_cfg.get('selections', {}).get('llm_id')}, "
|
||||
f"EMBEDDING_ID={updated_cfg.get('selections', {}).get('embedding_id')}")
|
||||
|
||||
# 记录成功的配置加载
|
||||
if audit_logger:
|
||||
selections = updated_cfg.get('selections', {})
|
||||
audit_logger.log_config_load(
|
||||
config_id=config_id,
|
||||
user_id=selections.get('user_id', None),
|
||||
group_id=selections.get('group_id', None),
|
||||
success=True,
|
||||
details={
|
||||
"llm_id": selections.get('llm_id'),
|
||||
"embedding_id": selections.get('embedding_id'),
|
||||
"chunker_strategy": selections.get('chunker_strategy')
|
||||
}
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[definitions] 重新加载配置时发生异常: {e}", exc_info=True)
|
||||
|
||||
# 记录配置加载异常
|
||||
if audit_logger:
|
||||
audit_logger.log_config_load(
|
||||
config_id=config_id,
|
||||
success=False,
|
||||
details={"error": str(e)}
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_current_config_id() -> Optional[str]:
|
||||
"""
|
||||
获取当前使用的 config_id
|
||||
|
||||
Returns:
|
||||
Optional[str]: 当前的 config_id,如果未设置则返回 None
|
||||
"""
|
||||
return SELECTIONS.get("config_id", None)
|
||||
|
||||
|
||||
def ensure_fresh_config(config_id: Optional[int | str] = None) -> bool:
|
||||
"""
|
||||
确保使用最新的配置(每次写入操作前调用)
|
||||
|
||||
如果提供了 config_id,则加载该配置;
|
||||
否则从 dbrun.json 读取并加载最新配置。
|
||||
|
||||
Args:
|
||||
config_id: 可选的配置ID(整数或字符串,会自动转换)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功加载配置
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
with _config_lock:
|
||||
try:
|
||||
if config_id:
|
||||
# 使用指定的 config_id
|
||||
logger.debug(f"[definitions] 加载指定配置,config_id={config_id}")
|
||||
return reload_configuration_from_database(config_id)
|
||||
else:
|
||||
# 从数据库重新加载配置
|
||||
logger.debug("[definitions] 从数据库重新加载最新配置")
|
||||
memory_config = _load_from_database()
|
||||
|
||||
if not memory_config or not memory_config.get('selections'):
|
||||
logger.warning("[definitions] 未能从数据库加载配置,使用当前配置")
|
||||
return False
|
||||
|
||||
_expose_memory_constants(memory_config)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[definitions] 加载配置失败: {e}", exc_info=True)
|
||||
return False
|
||||
# return False
|
||||
|
||||
|
||||
|
||||
@@ -1,609 +0,0 @@
|
||||
"""
|
||||
运行时配置覆写工具 - 统一配置加载器
|
||||
|
||||
本模块作为统一的配置加载器,负责从多个来源加载配置并按优先级覆写。
|
||||
|
||||
配置来源优先级(从高到低):
|
||||
1. 数据库配置(PostgreSQL data_config 表)
|
||||
2. 环境变量配置(.env 文件)
|
||||
3. 默认配置(runtime.json 文件)
|
||||
|
||||
支持的配置加载方式:
|
||||
- 基于 config_id 的配置加载(从 dbrun.json 读取或前端传入)
|
||||
- 基于 group_id 的配置加载(从 dbrun.json 读取)
|
||||
- 环境变量覆写(支持 INTERNAL/EXTERNAL 网络模式)
|
||||
|
||||
主要功能:
|
||||
- 从 PostgreSQL 数据库读取配置
|
||||
- 从环境变量读取配置
|
||||
- 从 runtime.json 读取默认配置
|
||||
- 按优先级覆写配置项(仅在内存中,不修改文件)
|
||||
- 支持多种配置字段:selections、statement_extraction、deduplication、forgetting_engine、pruning、reflexion
|
||||
|
||||
使用场景:
|
||||
- 应用启动时自动加载配置
|
||||
- 前端切换配置时动态重新加载
|
||||
- 多租户场景下的配置隔离
|
||||
- 内外网环境自动切换
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict, Any, Literal
|
||||
|
||||
NetworkMode = Literal['internal', 'external']
|
||||
|
||||
|
||||
def _set_if_present(target: Dict[str, Any], target_key: str, src: Dict[str, Any], src_key: str, caster):
|
||||
"""安全地设置目标字典的值(如果源字典中存在且不为 None)
|
||||
|
||||
Args:
|
||||
target: 目标字典
|
||||
target_key: 目标字典的键
|
||||
src: 源字典
|
||||
src_key: 源字典的键
|
||||
caster: 类型转换函数
|
||||
"""
|
||||
try:
|
||||
if src_key in src and src.get(src_key) is not None:
|
||||
try:
|
||||
target[target_key] = caster(src.get(src_key))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _to_bool(val: Any) -> bool:
|
||||
"""将各种类型的值转换为布尔值
|
||||
|
||||
支持的输入:
|
||||
- bool: 直接返回
|
||||
- int/float: 非零为 True
|
||||
- str: "true", "1", "on", "yes" 为 True;"false", "0", "off", "no" 为 False
|
||||
|
||||
Args:
|
||||
val: 要转换的值
|
||||
|
||||
Returns:
|
||||
bool: 转换后的布尔值
|
||||
"""
|
||||
try:
|
||||
if isinstance(val, bool):
|
||||
return val
|
||||
if isinstance(val, (int, float)):
|
||||
return bool(val)
|
||||
if isinstance(val, str):
|
||||
m = val.strip().lower()
|
||||
if m in {"true", "1", "on", "yes"}:
|
||||
return True
|
||||
if m in {"false", "0", "off", "no"}:
|
||||
return False
|
||||
return bool(val)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _make_pgsql_conn() -> Optional[object]:
|
||||
"""创建 PostgreSQL 数据库连接
|
||||
|
||||
使用环境变量配置连接参数:
|
||||
- DB_HOST: 数据库主机地址(默认 localhost)
|
||||
- DB_PORT: 数据库端口(默认 5432)
|
||||
- DB_USER: 数据库用户名
|
||||
- DB_PASSWORD: 数据库密码
|
||||
- DB_NAME: 数据库名称
|
||||
|
||||
Returns:
|
||||
Optional[object]: 数据库连接对象,失败时返回 None
|
||||
"""
|
||||
host = os.getenv("DB_HOST", "localhost")
|
||||
user = os.getenv("DB_USER")
|
||||
password = os.getenv("DB_PASSWORD")
|
||||
dbname = os.getenv("DB_NAME")
|
||||
port_str = os.getenv("DB_PORT")
|
||||
|
||||
try:
|
||||
import psycopg2 # type: ignore
|
||||
|
||||
port = int(port_str) if port_str else 5432
|
||||
conn = psycopg2.connect(
|
||||
host=host,
|
||||
port=port,
|
||||
user=user,
|
||||
password=password,
|
||||
dbname=dbname,
|
||||
)
|
||||
conn.autocommit = True
|
||||
return conn
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_db_config_by_group_id(group_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""根据 group_id 从数据库查询配置
|
||||
|
||||
Args:
|
||||
group_id: 组标识符
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 配置字典,未找到时返回 None
|
||||
"""
|
||||
conn = _make_pgsql_conn()
|
||||
if conn is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
from psycopg2.extras import RealDictCursor # type: ignore
|
||||
cur = conn.cursor(cursor_factory=RealDictCursor)
|
||||
|
||||
try:
|
||||
cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
sql = (
|
||||
"SELECT group_id, user_id, apply_id, chunker_strategy, "
|
||||
" enable_llm_dedup_blockwise, enable_llm_disambiguation "
|
||||
"FROM data_config WHERE group_id = %s ORDER BY updated_at DESC LIMIT 1"
|
||||
)
|
||||
cur.execute(sql, (group_id,))
|
||||
row = cur.fetchone()
|
||||
return row if row else None
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
try:
|
||||
cur.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _fetch_db_config_by_config_id(config_id: int | str) -> Optional[Dict[str, Any]]:
|
||||
"""根据 config_id 从数据库查询配置
|
||||
|
||||
Args:
|
||||
config_id: 配置标识符(整数或字符串,会自动转换为整数)
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 配置字典,未找到时返回 None
|
||||
"""
|
||||
conn = _make_pgsql_conn()
|
||||
if conn is None:
|
||||
try:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
try:
|
||||
from psycopg2.extras import RealDictCursor # type: ignore
|
||||
cur = conn.cursor(cursor_factory=RealDictCursor)
|
||||
|
||||
try:
|
||||
cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# config_id 在数据库中是 Integer 类型,需要转换
|
||||
try:
|
||||
config_id_int = int(config_id)
|
||||
except (ValueError, TypeError):
|
||||
try:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
sql = (
|
||||
"SELECT config_id, group_id, user_id, apply_id, chunker_strategy, "
|
||||
" enable_llm_dedup_blockwise, enable_llm_disambiguation, "
|
||||
" deep_retrieval, t_type_strict, t_name_strict, t_overall, state, "
|
||||
" statement_granularity, include_dialogue_context, max_context, "
|
||||
" \"offset\" AS offset, lambda_time, lambda_mem, "
|
||||
" pruning_enabled, pruning_scene, pruning_threshold, "
|
||||
" llm_id, embedding_id, rerank_id "
|
||||
"FROM data_config WHERE config_id = %s LIMIT 1"
|
||||
)
|
||||
cur.execute(sql, (config_id_int,))
|
||||
row = cur.fetchone()
|
||||
|
||||
if row:
|
||||
try:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
return row if row else None
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
finally:
|
||||
try:
|
||||
cur.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _load_dbrun_group_id(project_root: str) -> Optional[str]:
|
||||
"""从 dbrun.json 读取 group_id
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录路径
|
||||
|
||||
Returns:
|
||||
Optional[str]: group_id,未找到时返回 None
|
||||
"""
|
||||
try:
|
||||
path = os.path.join(project_root, "dbrun.json")
|
||||
if not os.path.isfile(path):
|
||||
return None
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if isinstance(data, dict):
|
||||
if "group_id" in data:
|
||||
return str(data.get("group_id"))
|
||||
sel = data.get("selections", {})
|
||||
if isinstance(sel, dict) and "group_id" in sel:
|
||||
return str(sel.get("group_id"))
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _load_dbrun_config_id(project_root: str) -> Optional[str]:
|
||||
"""从 dbrun.json 读取 config_id
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录路径
|
||||
|
||||
Returns:
|
||||
Optional[str]: config_id,未找到时返回 None
|
||||
"""
|
||||
try:
|
||||
path = os.path.join(project_root, "dbrun.json")
|
||||
if not os.path.isfile(path):
|
||||
return None
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if isinstance(data, dict):
|
||||
if "config_id" in data:
|
||||
return str(data.get("config_id"))
|
||||
sel = data.get("selections", {})
|
||||
if isinstance(sel, dict) and "config_id" in sel:
|
||||
return str(sel.get("config_id"))
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _apply_overrides_from_db_row(
|
||||
runtime_cfg: Dict[str, Any],
|
||||
db_row: Optional[Dict[str, Any]],
|
||||
identifier: str,
|
||||
identifier_type: str = "config_id"
|
||||
) -> Dict[str, Any]:
|
||||
"""从数据库行数据覆写运行时配置(统一处理函数)
|
||||
|
||||
Args:
|
||||
runtime_cfg: 运行时配置字典
|
||||
db_row: 数据库查询结果行
|
||||
identifier: 标识符值(group_id 或 config_id)
|
||||
identifier_type: 标识符类型("group_id" 或 "config_id")
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 覆写后的运行时配置
|
||||
"""
|
||||
try:
|
||||
selections = runtime_cfg.setdefault("selections", {})
|
||||
selections[identifier_type] = identifier
|
||||
|
||||
if not db_row:
|
||||
return runtime_cfg
|
||||
|
||||
# 覆写 selections 字段
|
||||
for tk in ("group_id", "user_id", "apply_id", "chunker_strategy", "state",
|
||||
"t_type_strict", "t_name_strict", "t_overall",
|
||||
"statement_granularity", "include_dialogue_context"):
|
||||
_set_if_present(selections, tk, db_row, tk, str)
|
||||
|
||||
# 特殊处理 UUID 字段,确保转换为字符串格式
|
||||
for uuid_field in ("llm_id", "embedding_id", "rerank_id"):
|
||||
if uuid_field in db_row and db_row.get(uuid_field) is not None:
|
||||
try:
|
||||
value = db_row.get(uuid_field)
|
||||
# 如果是 UUID 对象,转换为字符串(带连字符的标准格式)
|
||||
if hasattr(value, 'hex'):
|
||||
selections[uuid_field] = str(value)
|
||||
else:
|
||||
selections[uuid_field] = str(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 覆写 statement_extraction 字段
|
||||
stmt = runtime_cfg.setdefault("statement_extraction", {})
|
||||
_set_if_present(stmt, "statement_granularity", db_row, "statement_granularity", int)
|
||||
_set_if_present(stmt, "include_dialogue_context", db_row, "include_dialogue_context", _to_bool)
|
||||
_set_if_present(stmt, "max_dialogue_context_chars", db_row, "max_context", int)
|
||||
|
||||
# 覆写 deduplication 字段
|
||||
dedup = runtime_cfg.setdefault("deduplication", {})
|
||||
for tk in ("enable_llm_dedup_blockwise", "enable_llm_disambiguation"):
|
||||
_set_if_present(dedup, tk, db_row, tk, _to_bool)
|
||||
_set_if_present(dedup, "deep_retrieval", db_row, "deep_retrieval", _to_bool)
|
||||
|
||||
# 覆写 forgetting_engine 字段
|
||||
forgetting = runtime_cfg.setdefault("forgetting_engine", {})
|
||||
_set_if_present(forgetting, "offset", db_row, "offset", float)
|
||||
_set_if_present(forgetting, "lambda_time", db_row, "lambda_time", float)
|
||||
_set_if_present(forgetting, "lambda_mem", db_row, "lambda_mem", float)
|
||||
|
||||
# 覆写 pruning 字段
|
||||
pruning = runtime_cfg.setdefault("pruning", {})
|
||||
_set_if_present(pruning, "enabled", db_row, "pruning_enabled", _to_bool)
|
||||
_set_if_present(pruning, "scene", db_row, "pruning_scene", str)
|
||||
|
||||
# 阈值需要转为 float,且限制在 [0.0, 0.9]
|
||||
try:
|
||||
if "pruning_threshold" in db_row and db_row.get("pruning_threshold") is not None:
|
||||
thr = float(db_row.get("pruning_threshold"))
|
||||
thr = max(0.0, min(0.9, thr)) # 限制在 [0.0, 0.9]
|
||||
pruning["threshold"] = thr
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return runtime_cfg
|
||||
except Exception:
|
||||
pass
|
||||
return runtime_cfg
|
||||
|
||||
|
||||
def apply_runtime_overrides_by_group(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""基于 group_id 从数据库覆写运行时配置
|
||||
|
||||
工作流程:
|
||||
1. 从 dbrun.json 读取 group_id
|
||||
2. 根据 group_id 查询数据库配置
|
||||
3. 覆写运行时配置(仅在内存中)
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录路径
|
||||
runtime_cfg: 运行时配置字典
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 覆写后的运行时配置
|
||||
"""
|
||||
try:
|
||||
selected_gid = _load_dbrun_group_id(project_root)
|
||||
if not selected_gid:
|
||||
return runtime_cfg
|
||||
|
||||
db_row = _fetch_db_config_by_group_id(selected_gid)
|
||||
if not db_row:
|
||||
# 如果数据库中没有配置,仍然设置 group_id
|
||||
runtime_cfg.setdefault("selections", {})["group_id"] = selected_gid
|
||||
return runtime_cfg
|
||||
|
||||
return _apply_overrides_from_db_row(runtime_cfg, db_row, selected_gid, "group_id")
|
||||
except Exception:
|
||||
return runtime_cfg
|
||||
|
||||
|
||||
def apply_runtime_overrides_by_config(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""基于 config_id 从数据库覆写运行时配置(从 dbrun.json 读取)
|
||||
|
||||
工作流程:
|
||||
1. 从 dbrun.json 读取 config_id
|
||||
2. 根据 config_id 查询数据库配置
|
||||
3. 覆写运行时配置(仅在内存中)
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录路径
|
||||
runtime_cfg: 运行时配置字典
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 覆写后的运行时配置
|
||||
"""
|
||||
try:
|
||||
selected_cid = _load_dbrun_config_id(project_root)
|
||||
if not selected_cid:
|
||||
return runtime_cfg
|
||||
|
||||
db_row = _fetch_db_config_by_config_id(selected_cid)
|
||||
return _apply_overrides_from_db_row(runtime_cfg, db_row, selected_cid, "config_id")
|
||||
except Exception:
|
||||
return runtime_cfg
|
||||
|
||||
|
||||
def apply_runtime_overrides_with_config_id(
|
||||
project_root: str,
|
||||
runtime_cfg: Dict[str, Any],
|
||||
config_id: str
|
||||
) -> tuple[Dict[str, Any], bool]:
|
||||
"""使用指定的 config_id 从数据库覆写运行时配置(不读 dbrun.json)
|
||||
|
||||
用于前端动态切换配置的场景。
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录路径
|
||||
runtime_cfg: 运行时配置字典
|
||||
config_id: 配置标识符
|
||||
|
||||
Returns:
|
||||
tuple[Dict[str, Any], bool]: (覆写后的运行时配置, 是否成功从数据库加载)
|
||||
"""
|
||||
try:
|
||||
selected_cid = str(config_id).strip()
|
||||
if not selected_cid:
|
||||
return runtime_cfg, False
|
||||
|
||||
db_row = _fetch_db_config_by_config_id(selected_cid)
|
||||
if db_row is None:
|
||||
return runtime_cfg, False
|
||||
|
||||
updated_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, selected_cid, "config_id")
|
||||
return updated_cfg, True
|
||||
except Exception:
|
||||
pass
|
||||
return runtime_cfg, False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 以下函数已注释:不再需要网络模式自动检测功能
|
||||
# ============================================================================
|
||||
|
||||
# def get_server_ip() -> str:
|
||||
# """
|
||||
# 获取当前服务器的IP地址
|
||||
#
|
||||
# Returns:
|
||||
# 服务器IP地址字符串
|
||||
# """
|
||||
# try:
|
||||
# # 方式1:从环境变量获取(优先)
|
||||
# server_ip = os.getenv('SERVER_IP')
|
||||
# if server_ip and server_ip not in ['127.0.0.1', 'localhost', '0.0.0.0']:
|
||||
# return server_ip
|
||||
#
|
||||
# # 方式2:通过socket获取
|
||||
# hostname = socket.gethostname()
|
||||
# ip_address = socket.gethostbyname(hostname)
|
||||
#
|
||||
# # 如果是本地回环地址,尝试获取真实IP
|
||||
# if ip_address.startswith('127.'):
|
||||
# # 尝试连接外部地址来获取本机IP
|
||||
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
# try:
|
||||
# s.connect(('8.8.8.8', 80))
|
||||
# ip_address = s.getsockname()[0]
|
||||
# finally:
|
||||
# s.close()
|
||||
#
|
||||
# return ip_address
|
||||
# except Exception as e:
|
||||
# print(f"[overrides] 获取服务器IP失败: {e},使用默认值 127.0.0.1")
|
||||
# return '127.0.0.1'
|
||||
|
||||
|
||||
# def auto_detect_network_mode() -> NetworkMode:
|
||||
# """
|
||||
# 自动检测网络模式(基于服务器IP)
|
||||
#
|
||||
# 规则:
|
||||
# - 如果服务器IP在内网IP列表中 → internal(内网)
|
||||
# - 其他IP → external(外网)
|
||||
#
|
||||
# 可以通过环境变量 INTERNAL_SERVER_IPS 自定义内网IP列表(逗号分隔)
|
||||
#
|
||||
# Returns:
|
||||
# 'internal' 或 'external'
|
||||
# """
|
||||
# server_ip = get_server_ip()
|
||||
#
|
||||
# # 从环境变量获取内网IP列表(支持多个IP,逗号分隔)
|
||||
# internal_ips_str = os.getenv('INTERNAL_SERVER_IPS', '119.45.181.55')
|
||||
# internal_ips = [ip.strip() for ip in internal_ips_str.split(',')]
|
||||
#
|
||||
# # 判断当前IP是否在内网IP列表中
|
||||
# if server_ip in internal_ips:
|
||||
# print(f"[overrides] 自动检测:服务器IP {server_ip} 属于内网,使用 INTERNAL 配置")
|
||||
# return 'internal'
|
||||
# else:
|
||||
# print(f"[overrides] 自动检测:服务器IP {server_ip} 属于外网,使用 EXTERNAL 配置")
|
||||
# return 'external'
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 环境变量覆写功能已废弃 - 不再使用
|
||||
# ============================================================================
|
||||
# def _apply_env_var_overrides(runtime_cfg: Dict[str, Any], network_mode: NetworkMode = None, force_override: bool = False) -> Dict[str, Any]:
|
||||
# """
|
||||
# 从环境变量覆写配置(已废弃)
|
||||
# """
|
||||
# return runtime_cfg
|
||||
|
||||
|
||||
def load_unified_config(
|
||||
project_root: str,
|
||||
config_id: Optional[int | str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
network_mode: NetworkMode = None,
|
||||
env_override_models: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
统一配置加载器 - 按优先级加载配置
|
||||
|
||||
配置加载优先级:
|
||||
1. PG数据库配置(最高优先级,通过 dbrun.json 中的 config_id 读取)
|
||||
2. runtime.json 默认配置(最低优先级)
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录路径
|
||||
config_id: 配置ID(整数或字符串,可选,优先从 dbrun.json 读取)
|
||||
group_id: 组ID(可选)
|
||||
network_mode: 已废弃,保留参数仅为向后兼容
|
||||
env_override_models: 已废弃,保留参数仅为向后兼容
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 最终的运行时配置
|
||||
"""
|
||||
try:
|
||||
# 步骤 1: 加载 runtime.json 作为基础配置
|
||||
runtime_config_path = os.path.join(project_root, "runtime.json")
|
||||
try:
|
||||
with open(runtime_config_path, "r", encoding="utf-8") as f:
|
||||
runtime_cfg = json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
runtime_cfg = {"selections": {}}
|
||||
|
||||
# 步骤 2: 尝试从 dbrun.json 读取 config_id 并应用数据库配置(最高优先级)
|
||||
if config_id:
|
||||
# 优先使用传入的 config_id
|
||||
db_row = _fetch_db_config_by_config_id(config_id)
|
||||
if db_row:
|
||||
runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, config_id, "config_id")
|
||||
pass
|
||||
elif group_id:
|
||||
# 其次使用 group_id
|
||||
db_row = _fetch_db_config_by_group_id(group_id)
|
||||
if db_row:
|
||||
runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, group_id, "group_id")
|
||||
pass
|
||||
else:
|
||||
# 尝试从 dbrun.json 读取
|
||||
dbrun_config_id = _load_dbrun_config_id(project_root)
|
||||
if dbrun_config_id:
|
||||
db_row = _fetch_db_config_by_config_id(dbrun_config_id)
|
||||
if db_row:
|
||||
runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, dbrun_config_id, "config_id")
|
||||
pass
|
||||
else:
|
||||
dbrun_group_id = _load_dbrun_group_id(project_root)
|
||||
if dbrun_group_id:
|
||||
db_row = _fetch_db_config_by_group_id(dbrun_group_id)
|
||||
if db_row:
|
||||
runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, dbrun_group_id, "group_id")
|
||||
pass
|
||||
return runtime_cfg
|
||||
|
||||
except Exception:
|
||||
return {"selections": {}}
|
||||
|
||||
|
||||
# 向后兼容的别名
|
||||
apply_runtime_overrides = apply_runtime_overrides_by_config
|
||||
11
api/app/core/memory/utils/embedder/__init__.py
Normal file
11
api/app/core/memory/utils/embedder/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Embedder utilities module."""
|
||||
|
||||
from app.core.memory.utils.embedder.embedder_utils import (
|
||||
get_embedder_client,
|
||||
get_embedder_client_from_config,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_embedder_client",
|
||||
"get_embedder_client_from_config",
|
||||
]
|
||||
83
api/app/core/memory/utils/embedder/embedder_utils.py
Normal file
83
api/app/core/memory/utils/embedder/embedder_utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Embedder Client Utilities
|
||||
|
||||
This module provides centralized functions for creating embedder clients.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
def get_embedder_client_from_config(memory_config: "MemoryConfig") -> OpenAIEmbedderClient:
|
||||
"""
|
||||
Get embedder client from MemoryConfig object.
|
||||
|
||||
**PREFERRED METHOD**: Use this function in production code when you have a MemoryConfig object.
|
||||
This ensures proper configuration management and multi-tenant support.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing embedding_model_id
|
||||
|
||||
Returns:
|
||||
OpenAIEmbedderClient: Initialized embedder client
|
||||
|
||||
Raises:
|
||||
ValueError: If embedding model ID is not configured or client initialization fails
|
||||
|
||||
Example:
|
||||
>>> embedder_client = get_embedder_client_from_config(memory_config)
|
||||
"""
|
||||
if not memory_config.embedding_model_id:
|
||||
raise ValueError(
|
||||
f"Configuration {memory_config.config_id} has no embedding model configured"
|
||||
)
|
||||
return get_embedder_client(str(memory_config.embedding_model_id))
|
||||
|
||||
|
||||
def get_embedder_client(embedding_id: str) -> OpenAIEmbedderClient:
|
||||
"""
|
||||
Get embedder client by model ID.
|
||||
|
||||
**LEGACY/TEST METHOD**: Use this function only for:
|
||||
- Test/evaluation code where you have a model ID directly
|
||||
- Legacy code that hasn't been migrated to MemoryConfig yet
|
||||
|
||||
For production code with MemoryConfig, use get_embedder_client_from_config() instead.
|
||||
|
||||
Args:
|
||||
embedding_id: Embedding model ID (required)
|
||||
|
||||
Returns:
|
||||
OpenAIEmbedderClient: Initialized embedder client
|
||||
|
||||
Raises:
|
||||
ValueError: If embedding_id is not provided or client initialization fails
|
||||
|
||||
Example:
|
||||
>>> # For tests/evaluations only
|
||||
>>> embedder_client = get_embedder_client("model-uuid-string")
|
||||
"""
|
||||
if not embedding_id:
|
||||
raise ValueError("Embedding ID is required but was not provided")
|
||||
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
return embedder_client
|
||||
except Exception as e:
|
||||
model_name = embedder_config_dict.get('model_name', 'unknown')
|
||||
raise ValueError(
|
||||
f"Failed to initialize embedder client for model '{model_name}': {str(e)}"
|
||||
) from e
|
||||
@@ -1,77 +1,237 @@
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
async def handle_response(response: type[BaseModel]) -> dict:
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
def get_llm_client(llm_id: str | None = None):
|
||||
llm_id = llm_id or config_defs.SELECTED_LLM_ID
|
||||
|
||||
# Validate LLM ID exists before attempting to get config
|
||||
if not llm_id:
|
||||
raise ValueError("LLM ID is required but was not provided")
|
||||
|
||||
try:
|
||||
model_config = get_model_config(llm_id)
|
||||
except Exception as e:
|
||||
# Re-raise with clear error message about invalid LLM ID
|
||||
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
# 移除调试打印,避免污染终端输出
|
||||
# print(model_config)
|
||||
llm_client = OpenAIClient(RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),type_=model_config.get("type"))
|
||||
# print(llm.dict())
|
||||
return llm_client
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
|
||||
def get_reranker_client(rerank_id: str | None = None):
|
||||
class MemoryClientFactory:
|
||||
"""
|
||||
Get an LLM client configured for reranking.
|
||||
Factory for creating LLM, embedder, and reranker clients.
|
||||
|
||||
Initialize once with db session, then call methods without passing db each time.
|
||||
|
||||
Example:
|
||||
>>> factory = MemoryClientFactory(db)
|
||||
>>> llm_client = factory.get_llm_client(model_id)
|
||||
>>> embedder_client = factory.get_embedder_client(embedding_id)
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
self._config_service = MemoryConfigService(db)
|
||||
|
||||
def get_llm_client(self, llm_id: str) -> OpenAIClient:
|
||||
"""Get LLM client by model ID."""
|
||||
if not llm_id:
|
||||
raise ValueError("LLM ID is required")
|
||||
|
||||
try:
|
||||
model_config = self._config_service.get_model_config(llm_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
return OpenAIClient(
|
||||
RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),
|
||||
type_=model_config.get("type")
|
||||
)
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
def get_embedder_client(self, embedding_id: str):
|
||||
"""Get embedder client by model ID."""
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
|
||||
if not embedding_id:
|
||||
raise ValueError("Embedding ID is required")
|
||||
|
||||
try:
|
||||
embedder_config = self._config_service.get_embedder_config(embedding_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
return OpenAIEmbedderClient(
|
||||
RedBearModelConfig(
|
||||
model_name=embedder_config.get("model_name"),
|
||||
provider=embedder_config.get("provider"),
|
||||
api_key=embedder_config.get("api_key"),
|
||||
base_url=embedder_config.get("base_url")
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
model_name = embedder_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize embedder client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
def get_reranker_client(self, rerank_id: str) -> OpenAIClient:
|
||||
"""Get reranker client by model ID."""
|
||||
if not rerank_id:
|
||||
raise ValueError("Rerank ID is required")
|
||||
|
||||
try:
|
||||
model_config = self._config_service.get_model_config(rerank_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
return OpenAIClient(
|
||||
RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),
|
||||
type_=model_config.get("type")
|
||||
)
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
def get_llm_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
|
||||
"""Get LLM client from MemoryConfig object.
|
||||
|
||||
Args:
|
||||
memory_config: Configuration containing llm_model_id
|
||||
|
||||
Returns:
|
||||
OpenAIClient configured for the LLM model
|
||||
|
||||
Raises:
|
||||
ValueError: If memory_config has no LLM model configured
|
||||
"""
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"Configuration {memory_config.config_id} has no LLM model configured"
|
||||
)
|
||||
return self.get_llm_client(str(memory_config.llm_model_id))
|
||||
|
||||
def get_embedder_client_from_config(self, memory_config: "MemoryConfig"):
|
||||
"""Get embedder client from MemoryConfig object.
|
||||
|
||||
Args:
|
||||
memory_config: Configuration containing embedding_model_id
|
||||
|
||||
Returns:
|
||||
OpenAIEmbedderClient configured for the embedding model
|
||||
|
||||
Raises:
|
||||
ValueError: If memory_config has no embedding model configured
|
||||
"""
|
||||
if not memory_config.embedding_model_id:
|
||||
raise ValueError(
|
||||
f"Configuration {memory_config.config_id} has no embedding model configured"
|
||||
)
|
||||
return self.get_embedder_client(str(memory_config.embedding_model_id))
|
||||
|
||||
def get_reranker_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
|
||||
"""Get reranker client from MemoryConfig object.
|
||||
|
||||
Args:
|
||||
memory_config: Configuration containing rerank_model_id
|
||||
|
||||
Returns:
|
||||
OpenAIClient configured for the reranker model
|
||||
|
||||
Raises:
|
||||
ValueError: If memory_config has no rerank model configured
|
||||
"""
|
||||
if not memory_config.rerank_model_id:
|
||||
raise ValueError(
|
||||
f"Configuration {memory_config.config_id} has no rerank model configured"
|
||||
)
|
||||
return self.get_reranker_client(str(memory_config.rerank_model_id))
|
||||
|
||||
|
||||
# Legacy functions for backward compatibility
|
||||
def get_llm_client_from_config(memory_config: "MemoryConfig", db: Session) -> OpenAIClient:
|
||||
"""Get LLM client from MemoryConfig object.
|
||||
|
||||
DEPRECATED: Use MemoryClientFactory(db).get_llm_client_from_config(memory_config) instead.
|
||||
|
||||
This function is maintained for backward compatibility during migration to the
|
||||
factory pattern. New code should create a MemoryClientFactory instance and use
|
||||
its get_llm_client_from_config method directly.
|
||||
|
||||
Args:
|
||||
rerank_id: Optional reranker model ID. If None, uses SELECTED_RERANK_ID.
|
||||
memory_config: Configuration containing llm_model_id
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OpenAIClient: Initialized client for the reranker model
|
||||
OpenAIClient configured for the LLM model
|
||||
|
||||
Raises:
|
||||
ValueError: If rerank_id is invalid or client initialization fails
|
||||
ValueError: If memory_config has no LLM model configured
|
||||
"""
|
||||
rerank_id = rerank_id or config_defs.SELECTED_RERANK_ID
|
||||
return MemoryClientFactory(db).get_llm_client_from_config(memory_config)
|
||||
|
||||
|
||||
def get_llm_client(llm_id: str, db: Session) -> OpenAIClient:
|
||||
"""Get LLM client by model ID.
|
||||
|
||||
# Validate rerank ID exists before attempting to get config
|
||||
if not rerank_id:
|
||||
raise ValueError("Rerank ID is required but was not provided")
|
||||
DEPRECATED: Use MemoryClientFactory(db).get_llm_client(llm_id) instead.
|
||||
|
||||
try:
|
||||
model_config = get_model_config(rerank_id)
|
||||
except Exception as e:
|
||||
# Re-raise with clear error message about invalid rerank ID
|
||||
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
|
||||
This function is maintained for backward compatibility during migration to the
|
||||
factory pattern. New code should create a MemoryClientFactory instance and use
|
||||
its get_llm_client method directly.
|
||||
|
||||
try:
|
||||
reranker_client = OpenAIClient(RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),type_=model_config.get("type"))
|
||||
return reranker_client
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e
|
||||
Args:
|
||||
llm_id: LLM model ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OpenAIClient configured for the LLM model
|
||||
"""
|
||||
return MemoryClientFactory(db).get_llm_client(llm_id)
|
||||
|
||||
|
||||
def get_embedder_client(embedding_id: str, db: Session):
|
||||
"""Get embedder client by model ID.
|
||||
|
||||
DEPRECATED: Use MemoryClientFactory(db).get_embedder_client(embedding_id) instead.
|
||||
|
||||
This function is maintained for backward compatibility during migration to the
|
||||
factory pattern. New code should create a MemoryClientFactory instance and use
|
||||
its get_embedder_client method directly.
|
||||
|
||||
Args:
|
||||
embedding_id: Embedding model ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OpenAIEmbedderClient configured for the embedding model
|
||||
"""
|
||||
return MemoryClientFactory(db).get_embedder_client(embedding_id)
|
||||
|
||||
|
||||
def get_reranker_client(rerank_id: str, db: Session) -> OpenAIClient:
|
||||
"""Get reranker client by model ID.
|
||||
|
||||
DEPRECATED: Use MemoryClientFactory(db).get_reranker_client(rerank_id) instead.
|
||||
|
||||
This function is maintained for backward compatibility during migration to the
|
||||
factory pattern. New code should create a MemoryClientFactory instance and use
|
||||
its get_reranker_client method directly.
|
||||
|
||||
Args:
|
||||
rerank_id: Reranker model ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OpenAIClient configured for the reranker model
|
||||
"""
|
||||
return MemoryClientFactory(db).get_reranker_client(rerank_id)
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Any
|
||||
import time
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -25,7 +26,9 @@ async def conflict(evaluate_data: List[Any]) -> List[Any]:
|
||||
冲突记忆列表(JSON 数组)。
|
||||
"""
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
print(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Any
|
||||
import time
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -25,7 +26,9 @@ async def reflexion(ref_data: List[Any]) -> List[Any]:
|
||||
反思结果列表(JSON 数组)。
|
||||
"""
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema)
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
print(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
@@ -10,28 +10,29 @@
|
||||
从 app.core.memory.src.data_config_api.self_reflexion 迁移而来。
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
REFLEXION_ENABLED = os.getenv("REFLEXION_ENABLED", "false").lower() == "true"
|
||||
REFLEXION_ITERATION_PERIOD = os.getenv("REFLEXION_ITERATION_PERIOD", "3")
|
||||
REFLEXION_RANGE = os.getenv("REFLEXION_RANGE", "retrieval")
|
||||
REFLEXION_BASELINE = os.getenv("REFLEXION_BASELINE", "TIME")
|
||||
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
REFLEXION_ENABLED,
|
||||
REFLEXION_ITERATION_PERIOD,
|
||||
REFLEXION_RANGE,
|
||||
REFLEXION_BASELINE,
|
||||
)
|
||||
from app.db import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
|
||||
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
|
||||
from app.db import get_db
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 并发限制(可通过环境变量覆盖)
|
||||
CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5"))
|
||||
|
||||
@@ -5,16 +5,24 @@ This module provides functionality to analyze chunk content and generate insight
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
|
||||
|
||||
class ChunkInsight(BaseModel):
|
||||
"""Pydantic model for chunk insight."""
|
||||
insight: str = Field(..., description="对chunk内容的深度洞察分析")
|
||||
@@ -40,7 +48,7 @@ async def classify_chunk_domain(chunk: str) -> str:
|
||||
Domain name
|
||||
"""
|
||||
try:
|
||||
llm_client = get_llm_client()
|
||||
llm_client = _get_llm_client()
|
||||
|
||||
prompt = f"""请将以下文本内容归类到最合适的领域中。
|
||||
|
||||
@@ -177,7 +185,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str
|
||||
]
|
||||
|
||||
# 调用LLM生成洞察
|
||||
llm_client = get_llm_client()
|
||||
llm_client = _get_llm_client()
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
insight = response.content.strip()
|
||||
|
||||
@@ -5,15 +5,23 @@ This module provides functionality to summarize chunk content using LLM.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
|
||||
|
||||
class ChunkSummary(BaseModel):
|
||||
"""Pydantic model for chunk summary."""
|
||||
summary: str = Field(..., description="简洁的chunk内容摘要")
|
||||
@@ -59,7 +67,7 @@ async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str
|
||||
]
|
||||
|
||||
# 调用LLM生成摘要
|
||||
llm_client = get_llm_client()
|
||||
llm_client = _get_llm_client()
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
summary = response.content.strip()
|
||||
|
||||
@@ -7,14 +7,22 @@ This module provides functionality to extract meaningful tags from chunk content
|
||||
import asyncio
|
||||
from collections import Counter
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
|
||||
|
||||
class ExtractedTags(BaseModel):
|
||||
"""Pydantic model for extracted tags."""
|
||||
tags: List[str] = Field(..., description="从文本中提取的关键标签列表")
|
||||
@@ -56,7 +64,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
|
||||
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
|
||||
)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
llm_client = _get_llm_client()
|
||||
|
||||
# 为每个chunk单独提取标签,然后统计频率
|
||||
all_tags = []
|
||||
@@ -151,7 +159,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
|
||||
]
|
||||
|
||||
# 调用LLM提取人物形象
|
||||
llm_client = get_llm_client()
|
||||
llm_client = _get_llm_client()
|
||||
structured_response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=ExtractedPersona
|
||||
|
||||
@@ -1,6 +1,21 @@
|
||||
"""
|
||||
Validators for file upload system.
|
||||
Validators package for various validation utilities.
|
||||
"""
|
||||
from app.core.validators.file_validator import FileValidator, ValidationResult
|
||||
from app.core.validators.memory_config_validators import (
|
||||
validate_and_resolve_model_id,
|
||||
validate_embedding_model,
|
||||
validate_llm_model,
|
||||
validate_model_exists_and_active,
|
||||
)
|
||||
|
||||
__all__ = ["FileValidator", "ValidationResult"]
|
||||
__all__ = [
|
||||
# File validators
|
||||
"FileValidator",
|
||||
"ValidationResult",
|
||||
# Memory config validators
|
||||
"validate_model_exists_and_active",
|
||||
"validate_and_resolve_model_id",
|
||||
"validate_embedding_model",
|
||||
"validate_llm_model",
|
||||
]
|
||||
|
||||
250
api/app/core/validators/memory_config_validators.py
Normal file
250
api/app/core/validators/memory_config_validators.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Memory Configuration Validators
|
||||
|
||||
This module provides validation functions for memory configuration models.
|
||||
|
||||
Functions:
|
||||
validate_model_exists_and_active: Validate model exists and is active
|
||||
validate_and_resolve_model_id: Validate and resolve model ID with DB lookup
|
||||
validate_embedding_model: Validate embedding model availability
|
||||
validate_llm_model: Validate LLM model availability
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.logging_config import get_config_logger
|
||||
from app.schemas.memory_config_schema import (
|
||||
InvalidConfigError,
|
||||
ModelInactiveError,
|
||||
ModelNotFoundError,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_config_logger()
|
||||
|
||||
|
||||
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
|
||||
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
|
||||
"""Parse model ID from string or UUID."""
|
||||
if model_id is None:
|
||||
return None
|
||||
if isinstance(model_id, UUID):
|
||||
return model_id
|
||||
if isinstance(model_id, str):
|
||||
if not model_id.strip():
|
||||
return None
|
||||
try:
|
||||
return UUID(model_id.strip())
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid UUID format for {model_type} model ID: '{model_id}'",
|
||||
field_name=f"{model_type}_model_id",
|
||||
invalid_value=model_id,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
raise InvalidConfigError(
|
||||
f"Invalid type for {model_type} model ID: expected str or UUID, got {type(model_id).__name__}",
|
||||
field_name=f"{model_type}_model_id",
|
||||
invalid_value=model_id,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
|
||||
def validate_model_exists_and_active(
|
||||
model_id: UUID,
|
||||
model_type: str,
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
config_id: Optional[int] = None,
|
||||
workspace_id: Optional[UUID] = None
|
||||
) -> tuple[str, bool]:
|
||||
"""Validate that a model exists and is active.
|
||||
|
||||
Args:
|
||||
model_id: Model UUID to validate
|
||||
model_type: Type of model ("llm", "embedding", "rerank")
|
||||
db: Database session
|
||||
tenant_id: Optional tenant ID for filtering
|
||||
config_id: Optional configuration ID for error context
|
||||
workspace_id: Optional workspace ID for error context
|
||||
|
||||
Returns:
|
||||
Tuple of (model_name, is_active)
|
||||
|
||||
Raises:
|
||||
ModelNotFoundError: If model does not exist
|
||||
ModelInactiveError: If model exists but is inactive
|
||||
"""
|
||||
from app.repositories.model_repository import ModelConfigRepository
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
if not model:
|
||||
logger.warning(
|
||||
"Model not found",
|
||||
extra={"model_id": str(model_id), "model_type": model_type, "elapsed_ms": elapsed_ms}
|
||||
)
|
||||
raise ModelNotFoundError(
|
||||
model_id=model_id,
|
||||
model_type=model_type,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id,
|
||||
message=f"{model_type.title()} model {model_id} not found"
|
||||
)
|
||||
|
||||
if not model.is_active:
|
||||
logger.warning(
|
||||
"Model inactive",
|
||||
extra={"model_id": str(model_id), "model_name": model.name, "elapsed_ms": elapsed_ms}
|
||||
)
|
||||
raise ModelInactiveError(
|
||||
model_id=model_id,
|
||||
model_name=model.name,
|
||||
model_type=model_type,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id,
|
||||
message=f"{model_type.title()} model {model_id} ({model.name}) is inactive"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Model validation successful",
|
||||
extra={"model_id": str(model_id), "model_name": model.name, "elapsed_ms": elapsed_ms}
|
||||
)
|
||||
return model.name, model.is_active
|
||||
|
||||
except (ModelNotFoundError, ModelInactiveError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Model validation failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def validate_and_resolve_model_id(
|
||||
model_id_str: Union[str, UUID, None],
|
||||
model_type: str,
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
required: bool = False,
|
||||
config_id: Optional[int] = None,
|
||||
workspace_id: Optional[UUID] = None
|
||||
) -> tuple[Optional[UUID], Optional[str]]:
|
||||
"""Validate and resolve a model ID, checking existence and active status.
|
||||
|
||||
Returns:
|
||||
Tuple of (validated_uuid, model_name) or (None, None) if not required and empty
|
||||
"""
|
||||
if model_id_str is None or (isinstance(model_id_str, str) and not model_id_str.strip()):
|
||||
if required:
|
||||
raise InvalidConfigError(
|
||||
f"{model_type.title()} model ID is required",
|
||||
field_name=f"{model_type}_model_id",
|
||||
invalid_value=model_id_str,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return None, None
|
||||
|
||||
model_uuid = _parse_model_id(model_id_str, model_type, config_id, workspace_id)
|
||||
if model_uuid is None:
|
||||
if required:
|
||||
raise InvalidConfigError(
|
||||
f"{model_type.title()} model ID is required",
|
||||
field_name=f"{model_type}_model_id",
|
||||
invalid_value=model_id_str,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return None, None
|
||||
|
||||
model_name, _ = validate_model_exists_and_active(
|
||||
model_uuid, model_type, db, tenant_id, config_id, workspace_id
|
||||
)
|
||||
return model_uuid, model_name
|
||||
|
||||
|
||||
def validate_embedding_model(
|
||||
config_id: int,
|
||||
embedding_id: Union[str, UUID, None],
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None
|
||||
) -> UUID:
|
||||
"""Validate that embedding model is available and return its UUID.
|
||||
|
||||
Raises:
|
||||
InvalidConfigError: If embedding_id is not provided or invalid
|
||||
ModelNotFoundError: If embedding model does not exist
|
||||
ModelInactiveError: If embedding model is inactive
|
||||
"""
|
||||
if embedding_id is None or (isinstance(embedding_id, str) and not embedding_id.strip()):
|
||||
raise InvalidConfigError(
|
||||
f"Configuration {config_id} has no embedding model configured",
|
||||
field_name="embedding_model_id",
|
||||
invalid_value=embedding_id,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
embedding_uuid, _ = validate_and_resolve_model_id(
|
||||
embedding_id, "embedding", db, tenant_id, required=True,
|
||||
config_id=config_id, workspace_id=workspace_id
|
||||
)
|
||||
|
||||
if embedding_uuid is None:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration {config_id} has no embedding model configured",
|
||||
field_name="embedding_model_id",
|
||||
invalid_value=embedding_id,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return embedding_uuid
|
||||
|
||||
|
||||
def validate_llm_model(
|
||||
config_id: int,
|
||||
llm_id: Union[str, UUID, None],
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None
|
||||
) -> UUID:
|
||||
"""Validate that LLM model is available and return its UUID.
|
||||
|
||||
Raises:
|
||||
InvalidConfigError: If llm_id is not provided or invalid
|
||||
ModelNotFoundError: If LLM model does not exist
|
||||
ModelInactiveError: If LLM model is inactive
|
||||
"""
|
||||
if llm_id is None or (isinstance(llm_id, str) and not llm_id.strip()):
|
||||
raise InvalidConfigError(
|
||||
f"Configuration {config_id} has no LLM model configured",
|
||||
field_name="llm_model_id",
|
||||
invalid_value=llm_id,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
llm_uuid, _ = validate_and_resolve_model_id(
|
||||
llm_id, "llm", db, tenant_id, required=True,
|
||||
config_id=config_id, workspace_id=workspace_id
|
||||
)
|
||||
|
||||
if llm_uuid is None:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration {config_id} has no LLM model configured",
|
||||
field_name="llm_model_id",
|
||||
invalid_value=llm_id,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return llm_uuid
|
||||
39
api/app/models/memory_config_model.py
Normal file
39
api/app/models/memory_config_model.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Memory Configuration Model - Backward Compatibility
|
||||
|
||||
This module provides backward compatibility for imports.
|
||||
All classes have been moved to app.schemas.memory_config_schema.
|
||||
|
||||
DEPRECATED: Import from app.schemas.memory_config_schema instead.
|
||||
"""
|
||||
|
||||
# Re-export for backward compatibility
|
||||
from app.schemas.memory_config_schema import (
|
||||
ConfigurationError,
|
||||
InvalidConfigError,
|
||||
MemoryConfig,
|
||||
MemoryConfigValidation,
|
||||
ModelInactiveError,
|
||||
ModelNotFoundError,
|
||||
ModelValidation,
|
||||
WorkspaceNotFoundError,
|
||||
WorkspaceValidation,
|
||||
validate_memory_config_data,
|
||||
validate_model_data,
|
||||
validate_workspace_data,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConfigurationError",
|
||||
"InvalidConfigError",
|
||||
"MemoryConfig",
|
||||
"MemoryConfigValidation",
|
||||
"ModelInactiveError",
|
||||
"ModelNotFoundError",
|
||||
"ModelValidation",
|
||||
"WorkspaceNotFoundError",
|
||||
"WorkspaceValidation",
|
||||
"validate_memory_config_data",
|
||||
"validate_model_data",
|
||||
"validate_workspace_data",
|
||||
]
|
||||
@@ -8,22 +8,25 @@ Classes:
|
||||
DataConfigRepository: 数据配置仓储类,提供CRUD操作
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_db_logger
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
# 获取配置专用日志器
|
||||
config_logger = get_config_logger()
|
||||
|
||||
TABLE_NAME = "data_config"
|
||||
class DataConfigRepository:
|
||||
@@ -525,7 +528,129 @@ class DataConfigRepository:
|
||||
except Exception as e:
|
||||
db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
@staticmethod
|
||||
def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]:
|
||||
"""Get data config and its associated workspace information
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
config_id: Configuration ID
|
||||
|
||||
Returns:
|
||||
Optional[tuple]: (DataConfig, Workspace) tuple, None if not found
|
||||
|
||||
Raises:
|
||||
ValueError: Raised when config exists but workspace doesn't
|
||||
"""
|
||||
import time
|
||||
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Log configuration loading start
|
||||
config_logger.info(
|
||||
"Loading configuration with workspace",
|
||||
extra={
|
||||
"operation": "get_config_with_workspace",
|
||||
"config_id": config_id
|
||||
}
|
||||
)
|
||||
|
||||
db_logger.debug(f"Querying data config and workspace: config_id={config_id}")
|
||||
|
||||
try:
|
||||
# Use join query to get both config and workspace
|
||||
result = db.query(DataConfig, Workspace).join(
|
||||
Workspace, DataConfig.workspace_id == Workspace.id
|
||||
).filter(DataConfig.config_id == config_id).first()
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
if not result:
|
||||
# Check if config exists but workspace is missing
|
||||
config_only = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if config_only:
|
||||
if config_only.workspace_id is None:
|
||||
config_logger.error(
|
||||
"Configuration has no associated workspace ID",
|
||||
extra={
|
||||
"operation": "get_config_with_workspace",
|
||||
"config_id": config_id,
|
||||
"workspace_id": None,
|
||||
"load_result": "no_workspace_id",
|
||||
"elapsed_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
db_logger.error(f"Data config {config_id} has no associated workspace ID")
|
||||
raise ValueError(f"Configuration {config_id} has no associated workspace")
|
||||
else:
|
||||
config_logger.error(
|
||||
"Configuration references non-existent workspace",
|
||||
extra={
|
||||
"operation": "get_config_with_workspace",
|
||||
"config_id": config_id,
|
||||
"workspace_id": str(config_only.workspace_id),
|
||||
"load_result": "workspace_not_found",
|
||||
"elapsed_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
db_logger.error(f"Data config {config_id} references non-existent workspace {config_only.workspace_id}")
|
||||
raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
|
||||
|
||||
config_logger.debug(
|
||||
"Configuration not found",
|
||||
extra={
|
||||
"operation": "get_config_with_workspace",
|
||||
"config_id": config_id,
|
||||
"load_result": "not_found",
|
||||
"elapsed_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
db_logger.debug(f"Data config not found: config_id={config_id}")
|
||||
return None
|
||||
|
||||
config, workspace = result
|
||||
|
||||
# Log successful configuration loading
|
||||
config_logger.info(
|
||||
"Configuration with workspace loaded successfully",
|
||||
extra={
|
||||
"operation": "get_config_with_workspace",
|
||||
"config_id": config_id,
|
||||
"config_name": config.config_name,
|
||||
"workspace_id": str(workspace.id),
|
||||
"workspace_name": workspace.name,
|
||||
"tenant_id": str(workspace.tenant_id),
|
||||
"load_result": "success",
|
||||
"elapsed_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
|
||||
db_logger.debug(f"Data config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||
return (config, workspace)
|
||||
|
||||
except ValueError:
|
||||
# Re-raise known business exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
config_logger.error(
|
||||
"Failed to load configuration with workspace",
|
||||
extra={
|
||||
"operation": "get_config_with_workspace",
|
||||
"config_id": config_id,
|
||||
"load_result": "error",
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"elapsed_ms": elapsed_ms
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
db_logger.error(f"Failed to query data config and workspace: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
@staticmethod
|
||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
|
||||
"""获取所有配置参数
|
||||
|
||||
474
api/app/schemas/memory_config_schema.py
Normal file
474
api/app/schemas/memory_config_schema.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Memory Configuration Schemas
|
||||
|
||||
This module provides schema definitions for memory configuration.
|
||||
|
||||
Classes:
|
||||
MemoryConfig: Immutable memory configuration loaded from database
|
||||
MemoryConfigValidation: Pydantic model for configuration validation
|
||||
WorkspaceValidation: Pydantic model for workspace validation
|
||||
ModelValidation: Pydantic model for model configuration validation
|
||||
ConfigurationError: Base exception for configuration-related errors
|
||||
WorkspaceNotFoundError: Raised when workspace does not exist
|
||||
ModelNotFoundError: Raised when a required model does not exist
|
||||
ModelInactiveError: Raised when a required model exists but is inactive
|
||||
InvalidConfigError: Raised when configuration validation fails
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
|
||||
# ==================== Configuration Exception Classes ====================
|
||||
|
||||
|
||||
class ConfigurationError(Exception):
|
||||
"""Base exception for configuration-related errors.
|
||||
|
||||
This exception includes context information to help with debugging
|
||||
and provides detailed error messages for different failure scenarios.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
config_id: Optional[int] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Initialize configuration error with context.
|
||||
|
||||
Args:
|
||||
message: Error message describing the failure
|
||||
config_id: Optional configuration ID for context
|
||||
workspace_id: Optional workspace ID for context
|
||||
context: Optional additional context information
|
||||
"""
|
||||
self.config_id = config_id
|
||||
self.workspace_id = workspace_id
|
||||
self.context = context or {}
|
||||
|
||||
# Build detailed error message with context
|
||||
detailed_message = message
|
||||
if config_id is not None:
|
||||
detailed_message = f"Configuration {config_id}: {message}"
|
||||
if workspace_id is not None:
|
||||
detailed_message = f"{detailed_message} (workspace: {workspace_id})"
|
||||
|
||||
# Add context information if available
|
||||
if self.context:
|
||||
context_str = ", ".join(f"{k}={v}" for k, v in self.context.items())
|
||||
detailed_message = f"{detailed_message} [Context: {context_str}]"
|
||||
|
||||
super().__init__(detailed_message)
|
||||
|
||||
|
||||
class WorkspaceNotFoundError(ConfigurationError):
|
||||
"""Raised when workspace does not exist."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_id: UUID,
|
||||
config_id: Optional[int] = None,
|
||||
message: Optional[str] = None,
|
||||
):
|
||||
if message is None:
|
||||
message = f"Workspace {workspace_id} not found in database"
|
||||
|
||||
context = {"workspace_id": str(workspace_id)}
|
||||
super().__init__(message, config_id=config_id, workspace_id=workspace_id, context=context)
|
||||
|
||||
|
||||
class ModelNotFoundError(ConfigurationError):
|
||||
"""Raised when a required model does not exist."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: Union[str, UUID],
|
||||
model_type: str,
|
||||
config_id: Optional[int] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
message: Optional[str] = None,
|
||||
):
|
||||
if message is None:
|
||||
message = f"{model_type.title()} model {model_id} not found in database"
|
||||
|
||||
context = {
|
||||
"model_id": str(model_id),
|
||||
"model_type": model_type,
|
||||
"failure_type": "not_found",
|
||||
}
|
||||
super().__init__(message, config_id=config_id, workspace_id=workspace_id, context=context)
|
||||
|
||||
|
||||
class ModelInactiveError(ConfigurationError):
|
||||
"""Raised when a required model exists but is inactive."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: Union[str, UUID],
|
||||
model_name: str,
|
||||
model_type: str,
|
||||
config_id: Optional[int] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
message: Optional[str] = None,
|
||||
):
|
||||
if message is None:
|
||||
message = f"{model_type.title()} model {model_id} ({model_name}) is inactive"
|
||||
|
||||
context = {
|
||||
"model_id": str(model_id),
|
||||
"model_name": model_name,
|
||||
"model_type": model_type,
|
||||
"failure_type": "inactive",
|
||||
}
|
||||
super().__init__(message, config_id=config_id, workspace_id=workspace_id, context=context)
|
||||
|
||||
|
||||
class InvalidConfigError(ConfigurationError):
|
||||
"""Raised when configuration validation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
field_name: Optional[str] = None,
|
||||
invalid_value: Optional[Any] = None,
|
||||
config_id: Optional[int] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
):
|
||||
context = {}
|
||||
if field_name is not None:
|
||||
context["field_name"] = field_name
|
||||
if invalid_value is not None:
|
||||
context["invalid_value"] = str(invalid_value)
|
||||
context["invalid_value_type"] = type(invalid_value).__name__
|
||||
|
||||
super().__init__(message, config_id=config_id, workspace_id=workspace_id, context=context)
|
||||
|
||||
|
||||
# ==================== Pydantic Validation Models ====================
|
||||
|
||||
|
||||
class MemoryConfigValidation(BaseModel):
|
||||
"""Pydantic model for validating memory configuration data from database."""
|
||||
|
||||
config_id: int = Field(..., gt=0, description="Configuration ID must be positive")
|
||||
config_name: str = Field(..., min_length=1, max_length=255)
|
||||
workspace_id: UUID = Field(..., description="Workspace UUID")
|
||||
workspace_name: str = Field(..., min_length=1, max_length=255)
|
||||
tenant_id: UUID = Field(..., description="Tenant UUID")
|
||||
|
||||
embedding_model_id: UUID = Field(..., description="Embedding model UUID (required)")
|
||||
embedding_model_name: str = Field(..., min_length=1, max_length=255)
|
||||
llm_model_id: UUID = Field(..., description="LLM model UUID (required)")
|
||||
llm_model_name: str = Field(..., min_length=1, max_length=255)
|
||||
rerank_model_id: Optional[UUID] = Field(None, description="Rerank model UUID (optional)")
|
||||
rerank_model_name: Optional[str] = Field(None, max_length=255)
|
||||
|
||||
storage_type: str = Field(..., min_length=1, max_length=50)
|
||||
|
||||
chunker_strategy: str = Field(default="RecursiveChunker", min_length=1, max_length=100)
|
||||
reflexion_enabled: bool = Field(default=False)
|
||||
reflexion_iteration_period: int = Field(default=3, ge=1, le=100)
|
||||
reflexion_range: Literal["retrieval", "all"] = Field(default="retrieval")
|
||||
reflexion_baseline: Literal["time", "fact", "time_and_fact"] = Field(default="time")
|
||||
|
||||
llm_params: Dict[str, Any] = Field(default_factory=dict)
|
||||
embedding_params: Dict[str, Any] = Field(default_factory=dict)
|
||||
config_version: str = Field(default="2.0", min_length=1, max_length=10)
|
||||
|
||||
@field_validator("config_name", "workspace_name", "embedding_model_name", "llm_model_name")
|
||||
@classmethod
|
||||
def validate_non_empty_strings(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("Field cannot be empty or whitespace-only")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("storage_type")
|
||||
@classmethod
|
||||
def validate_storage_type(cls, v):
|
||||
valid_types = ["neo4j", "elasticsearch", "qdrant", "milvus", "chroma"]
|
||||
if v.lower() not in valid_types:
|
||||
raise ValueError(f"Storage type must be one of: {valid_types}")
|
||||
return v.lower()
|
||||
|
||||
@field_validator("llm_params", "embedding_params")
|
||||
@classmethod
|
||||
def validate_model_params(cls, v):
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError("Model parameters must be a dictionary")
|
||||
reserved_keys = ["model_id", "model_name", "api_key", "base_url"]
|
||||
for key in v.keys():
|
||||
if key in reserved_keys:
|
||||
raise ValueError(f"Model parameters cannot contain reserved parameter '{key}'")
|
||||
return v
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True, extra="forbid")
|
||||
|
||||
|
||||
class WorkspaceValidation(BaseModel):
|
||||
"""Pydantic model for validating workspace data from database."""
|
||||
|
||||
id: UUID = Field(..., description="Workspace UUID")
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
tenant_id: UUID = Field(..., description="Tenant UUID")
|
||||
storage_type: Optional[str] = Field(None, max_length=50)
|
||||
llm: Optional[str] = Field(None)
|
||||
embedding: Optional[str] = Field(None)
|
||||
rerank: Optional[str] = Field(None)
|
||||
is_active: bool = Field(default=True)
|
||||
|
||||
@field_validator("llm", "embedding", "rerank")
|
||||
@classmethod
|
||||
def validate_model_ids(cls, v):
|
||||
if v is None or v == "":
|
||||
return None
|
||||
try:
|
||||
UUID(v.strip())
|
||||
except ValueError:
|
||||
raise ValueError("Model ID must be a valid UUID string")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("is_active")
|
||||
@classmethod
|
||||
def validate_active_status(cls, v):
|
||||
if not v:
|
||||
raise ValueError("Workspace must be active for configuration loading")
|
||||
return v
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True, extra="forbid")
|
||||
|
||||
|
||||
class ModelValidation(BaseModel):
|
||||
"""Pydantic model for validating model configuration data."""
|
||||
|
||||
id: UUID = Field(..., description="Model UUID")
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
type: str = Field(..., description="Model type (llm, embedding, rerank)")
|
||||
tenant_id: UUID = Field(..., description="Tenant UUID")
|
||||
is_active: bool = Field(..., description="Whether model is active")
|
||||
is_public: bool = Field(default=False)
|
||||
|
||||
@field_validator("type")
|
||||
@classmethod
|
||||
def validate_type(cls, v):
|
||||
valid_types = ["llm", "embedding", "rerank"]
|
||||
if v.lower() not in valid_types:
|
||||
raise ValueError(f"Model type must be one of: {valid_types}")
|
||||
return v.lower()
|
||||
|
||||
@field_validator("is_active")
|
||||
@classmethod
|
||||
def validate_active_status(cls, v):
|
||||
if not v:
|
||||
raise ValueError("Model must be active for configuration use")
|
||||
return v
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True, extra="forbid")
|
||||
|
||||
|
||||
# ==================== Validation Helper Functions ====================
|
||||
|
||||
|
||||
def validate_memory_config_data(
|
||||
config_data: Dict[str, Any], config_id: Optional[int] = None
|
||||
) -> MemoryConfigValidation:
|
||||
"""Validate memory configuration data using Pydantic model."""
|
||||
try:
|
||||
return MemoryConfigValidation(**config_data)
|
||||
except ValidationError as e:
|
||||
error_messages = []
|
||||
for error in e.errors():
|
||||
field_path = " -> ".join(str(loc) for loc in error["loc"])
|
||||
error_messages.append(f"Field '{field_path}': {error['msg']}")
|
||||
|
||||
detailed_message = "Configuration validation failed:\n" + "\n".join(
|
||||
f" - {msg}" for msg in error_messages
|
||||
)
|
||||
|
||||
first_error = e.errors()[0] if e.errors() else {}
|
||||
first_field = " -> ".join(str(loc) for loc in first_error.get("loc", []))
|
||||
|
||||
raise InvalidConfigError(
|
||||
detailed_message,
|
||||
field_name=first_field or None,
|
||||
invalid_value=first_error.get("input"),
|
||||
config_id=config_id,
|
||||
)
|
||||
|
||||
|
||||
def validate_workspace_data(
|
||||
workspace_data: Dict[str, Any], config_id: Optional[int] = None
|
||||
) -> WorkspaceValidation:
|
||||
"""Validate workspace data using Pydantic model."""
|
||||
try:
|
||||
return WorkspaceValidation(**workspace_data)
|
||||
except ValidationError as e:
|
||||
error_messages = []
|
||||
for error in e.errors():
|
||||
field_path = " -> ".join(str(loc) for loc in error["loc"])
|
||||
error_messages.append(f"Field '{field_path}': {error['msg']}")
|
||||
|
||||
detailed_message = "Workspace validation failed:\n" + "\n".join(
|
||||
f" - {msg}" for msg in error_messages
|
||||
)
|
||||
|
||||
first_error = e.errors()[0] if e.errors() else {}
|
||||
first_field = " -> ".join(str(loc) for loc in first_error.get("loc", []))
|
||||
workspace_id = workspace_data.get("id") if isinstance(workspace_data, dict) else None
|
||||
|
||||
raise InvalidConfigError(
|
||||
detailed_message,
|
||||
field_name=first_field or None,
|
||||
invalid_value=first_error.get("input"),
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
|
||||
|
||||
def validate_model_data(
|
||||
model_data: Dict[str, Any], config_id: Optional[int] = None
|
||||
) -> ModelValidation:
|
||||
"""Validate model data using Pydantic model."""
|
||||
try:
|
||||
return ModelValidation(**model_data)
|
||||
except ValidationError as e:
|
||||
error_messages = []
|
||||
for error in e.errors():
|
||||
field_path = " -> ".join(str(loc) for loc in error["loc"])
|
||||
error_messages.append(f"Field '{field_path}': {error['msg']}")
|
||||
|
||||
detailed_message = "Model validation failed:\n" + "\n".join(
|
||||
f" - {msg}" for msg in error_messages
|
||||
)
|
||||
|
||||
first_error = e.errors()[0] if e.errors() else {}
|
||||
first_field = " -> ".join(str(loc) for loc in first_error.get("loc", []))
|
||||
|
||||
raise InvalidConfigError(
|
||||
detailed_message,
|
||||
field_name=first_field or None,
|
||||
invalid_value=first_error.get("input"),
|
||||
config_id=config_id,
|
||||
)
|
||||
|
||||
|
||||
# ==================== Immutable Configuration Data Structure ====================
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MemoryConfig:
|
||||
"""Immutable memory configuration loaded from database."""
|
||||
|
||||
config_id: int
|
||||
config_name: str
|
||||
workspace_id: UUID
|
||||
workspace_name: str
|
||||
tenant_id: UUID
|
||||
|
||||
embedding_model_id: UUID
|
||||
embedding_model_name: str
|
||||
llm_model_id: UUID
|
||||
llm_model_name: str
|
||||
|
||||
storage_type: str
|
||||
|
||||
chunker_strategy: str
|
||||
reflexion_enabled: bool
|
||||
reflexion_iteration_period: int
|
||||
reflexion_range: str
|
||||
reflexion_baseline: str
|
||||
|
||||
loaded_at: datetime
|
||||
|
||||
rerank_model_id: Optional[UUID] = None
|
||||
rerank_model_name: Optional[str] = None
|
||||
|
||||
llm_params: Dict[str, Any] = field(default_factory=dict)
|
||||
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
||||
config_version: str = "2.0"
|
||||
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise: bool = False
|
||||
enable_llm_disambiguation: bool = False
|
||||
deep_retrieval: bool = True
|
||||
t_type_strict: float = 0.8
|
||||
t_name_strict: float = 0.8
|
||||
t_overall: float = 0.8
|
||||
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity: int = 2
|
||||
include_dialogue_context: bool = False
|
||||
max_dialogue_context_chars: int = 1000
|
||||
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time: float = 0.5
|
||||
lambda_mem: float = 0.5
|
||||
offset: float = 0.0
|
||||
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled: bool = False
|
||||
pruning_scene: Optional[str] = "education"
|
||||
pruning_threshold: float = 0.5
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if not self.config_name or not self.config_name.strip():
|
||||
raise InvalidConfigError("Configuration name cannot be empty")
|
||||
|
||||
if not self.embedding_model_id:
|
||||
raise InvalidConfigError("Embedding model ID is required")
|
||||
|
||||
if not self.llm_model_id:
|
||||
raise InvalidConfigError("LLM model ID is required")
|
||||
|
||||
@classmethod
|
||||
def from_validated_data(
|
||||
cls, validated_config: MemoryConfigValidation, loaded_at: datetime
|
||||
) -> "MemoryConfig":
|
||||
"""Create MemoryConfig from validated Pydantic data."""
|
||||
return cls(
|
||||
config_id=validated_config.config_id,
|
||||
config_name=validated_config.config_name,
|
||||
workspace_id=validated_config.workspace_id,
|
||||
workspace_name=validated_config.workspace_name,
|
||||
tenant_id=validated_config.tenant_id,
|
||||
embedding_model_id=validated_config.embedding_model_id,
|
||||
embedding_model_name=validated_config.embedding_model_name,
|
||||
storage_type=validated_config.storage_type,
|
||||
chunker_strategy=validated_config.chunker_strategy,
|
||||
reflexion_enabled=validated_config.reflexion_enabled,
|
||||
reflexion_iteration_period=validated_config.reflexion_iteration_period,
|
||||
reflexion_range=validated_config.reflexion_range,
|
||||
reflexion_baseline=validated_config.reflexion_baseline,
|
||||
loaded_at=loaded_at,
|
||||
llm_model_id=validated_config.llm_model_id,
|
||||
llm_model_name=validated_config.llm_model_name,
|
||||
rerank_model_id=validated_config.rerank_model_id,
|
||||
rerank_model_name=validated_config.rerank_model_name,
|
||||
llm_params=validated_config.llm_params,
|
||||
embedding_params=validated_config.embedding_params,
|
||||
config_version=validated_config.config_version,
|
||||
)
|
||||
|
||||
def get_model_summary(self) -> Dict[str, Optional[str]]:
|
||||
"""Get a summary of configured models."""
|
||||
return {
|
||||
"llm": self.llm_model_name,
|
||||
"embedding": self.embedding_model_name,
|
||||
"rerank": self.rerank_model_name,
|
||||
}
|
||||
|
||||
def is_model_configured(self, model_type: str) -> bool:
|
||||
"""Check if a specific model type is configured."""
|
||||
if model_type == "llm":
|
||||
return True
|
||||
elif model_type == "embedding":
|
||||
return True
|
||||
elif model_type == "rerank":
|
||||
return self.rerank_model_id is not None
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
@@ -3,26 +3,26 @@
|
||||
|
||||
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Dict, Any, Optional, List, AsyncGenerator
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.models import AgentConfig, ModelConfig, ModelApiKey
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.services.langchain_tool_server import Search
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_business_logger()
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
@@ -83,17 +83,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
"""
|
||||
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
||||
try:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=question,
|
||||
history=[],
|
||||
search_switch="1",
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
from app.db import get_db
|
||||
db = next(get_db())
|
||||
try:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=question,
|
||||
history=[],
|
||||
search_switch="1",
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
)
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
logger.info(f'用户ID:Agent:{end_user_id}')
|
||||
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
||||
|
||||
@@ -713,9 +719,9 @@ class DraftRunService:
|
||||
Raises:
|
||||
BusinessException: 当指定的会话不存在时
|
||||
"""
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.schemas.conversation_schema import ConversationCreate
|
||||
from app.models import Conversation as ConversationModel
|
||||
from app.schemas.conversation_schema import ConversationCreate
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
conversation_service = ConversationService(self.db)
|
||||
|
||||
|
||||
@@ -7,14 +7,15 @@ Classes:
|
||||
EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import statistics
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
import statistics
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories.neo4j.emotion_repository import EmotionRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.logging_config import get_business_logger
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -454,7 +455,7 @@ class EmotionAnalyticsService:
|
||||
async def generate_emotion_suggestions(
|
||||
self,
|
||||
end_user_id: str,
|
||||
config_id: Optional[int] = None
|
||||
db: Session,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成个性化情绪建议
|
||||
|
||||
@@ -462,7 +463,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID(用户组ID)
|
||||
config_id: 配置ID(可选,用于从数据库加载LLM配置)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 包含个性化建议的响应:
|
||||
@@ -470,14 +471,32 @@ class EmotionAnalyticsService:
|
||||
- suggestions: 建议列表(3-5条)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}")
|
||||
logger.info(f"生成个性化情绪建议: user={end_user_id}")
|
||||
|
||||
# 1. 如果提供了 config_id,从数据库加载配置
|
||||
if config_id is not None:
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置")
|
||||
# 1. 从 end_user_id 获取关联的 memory_config_id
|
||||
llm_client = None
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is not None:
|
||||
from app.services.memory_config_service import (
|
||||
MemoryConfigService,
|
||||
)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=int(config_id),
|
||||
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
|
||||
|
||||
# 2. 获取情绪健康数据
|
||||
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
|
||||
@@ -498,8 +517,9 @@ class EmotionAnalyticsService:
|
||||
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
|
||||
|
||||
# 7. 调用LLM生成建议(使用配置中的LLM)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
llm_client = get_llm_client()
|
||||
if llm_client is None:
|
||||
# 无法获取配置时,抛出错误而不是使用默认配置
|
||||
raise ValueError("无法获取LLM配置,请确保end_user关联了有效的memory_config")
|
||||
|
||||
# 将 prompt 转换为 messages 格式
|
||||
messages = [
|
||||
@@ -598,7 +618,9 @@ class EmotionAnalyticsService:
|
||||
Returns:
|
||||
str: LLM prompt
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt
|
||||
from app.core.memory.utils.prompt.prompt_utils import (
|
||||
render_emotion_suggestions_prompt,
|
||||
)
|
||||
|
||||
prompt = await render_emotion_suggestions_prompt(
|
||||
health_data=health_data,
|
||||
|
||||
@@ -9,10 +9,12 @@ Classes:
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.data_config_model import DataConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,7 +52,9 @@ class EmotionExtractionService:
|
||||
"""
|
||||
if self.llm_client is None or model_id:
|
||||
effective_model_id = model_id or self.llm_id
|
||||
self.llm_client = get_llm_client(effective_model_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(effective_model_id)
|
||||
return self.llm_client
|
||||
|
||||
async def extract_emotion(
|
||||
@@ -142,7 +146,9 @@ class EmotionExtractionService:
|
||||
Returns:
|
||||
Formatted prompt string for LLM
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
|
||||
from app.core.memory.utils.prompt.prompt_utils import (
|
||||
render_emotion_extraction_prompt,
|
||||
)
|
||||
|
||||
prompt = await render_emotion_extraction_prompt(
|
||||
statement=statement,
|
||||
|
||||
@@ -4,50 +4,48 @@ Memory Agent Service
|
||||
Handles business logic for memory agent operations including read/write services,
|
||||
health checks, and message type classification.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.db import get_db
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.schemas.memory_storage_schema import ApiResponse, ok, fail
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.services.memory_konwledges_server import memory_konwledges_up, SimpleUser, find_document_id_by_kb_and_filename
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
from app.schemas.file_schema import CustomTextFileCreate
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
# Initialize Neo4j connector for analytics functions
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
|
||||
class MemoryAgentService:
|
||||
"""Service for memory agent operations"""
|
||||
@@ -257,14 +255,17 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str:
|
||||
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
Args:
|
||||
group_id: Group identifier
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
user_rag_memory_id: User RAG memory ID
|
||||
|
||||
Returns:
|
||||
Write operation result status
|
||||
@@ -272,24 +273,40 @@ class MemoryAgentService:
|
||||
Raises:
|
||||
ValueError: If config loading fails or write operation fails
|
||||
"""
|
||||
if config_id==None:
|
||||
config_id = os.getenv("config_id")
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
if config_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 如果 config_id 为 None,使用默认值 "17"
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}"
|
||||
# Load configuration from database only
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 记录失败的操作
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation( operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg )
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
|
||||
mcp_config = get_mcp_server_config()
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
|
||||
@@ -300,20 +317,43 @@ class MemoryAgentService:
|
||||
async with client.session("data_flow") as session:
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
tools = await load_mcp_tools(session)
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
# Pass config_id to the graph workflow
|
||||
async with make_write_graph(group_id, tools, group_id, group_id, config_id=config_id) as graph:
|
||||
# Pass memory_config to the graph workflow
|
||||
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
|
||||
logger.debug("Write graph created successfully")
|
||||
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
|
||||
async for event in graph.astream(
|
||||
{"messages": message, "config_id": config_id},
|
||||
{"messages": message, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
messages = event.get('messages')
|
||||
return self.writer_messages_deal(messages,start_time,group_id,config_id,message)
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.error(f"Write workflow failed with errors: {error_details}")
|
||||
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_details
|
||||
)
|
||||
|
||||
raise ValueError(f"Write workflow failed: {error_details}")
|
||||
|
||||
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
@@ -321,7 +361,8 @@ class MemoryAgentService:
|
||||
message: str,
|
||||
history: List[Dict],
|
||||
search_switch: str,
|
||||
config_id: str,
|
||||
config_id: Optional[str],
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
) -> Dict:
|
||||
@@ -334,11 +375,14 @@ class MemoryAgentService:
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
group_id: Group identifier
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
message: User message
|
||||
history: Conversation history
|
||||
search_switch: Search mode switch
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
user_rag_memory_id: User RAG memory ID
|
||||
|
||||
Returns:
|
||||
Dict with 'answer' and 'intermediate_outputs' keys
|
||||
@@ -350,8 +394,18 @@ class MemoryAgentService:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
if config_id==None:
|
||||
config_id = os.getenv("config_id")
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
if config_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
|
||||
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
|
||||
|
||||
@@ -365,15 +419,19 @@ class MemoryAgentService:
|
||||
group_lock = self.get_group_lock(group_id)
|
||||
|
||||
with group_lock:
|
||||
# Step 1: Load configuration from database
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}"
|
||||
# Step 1: Load configuration from database only
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 记录失败的操作
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
@@ -387,8 +445,6 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
|
||||
|
||||
# Step 2: Prepare history
|
||||
history.append({"role": "user", "content": message})
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
@@ -404,45 +460,52 @@ class MemoryAgentService:
|
||||
intermediate_outputs = []
|
||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
||||
|
||||
# Pass config_id to the graph workflow
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, config_id=config_id,storage_type=storage_type,user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
# Pass memory_config to the graph workflow
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
start = time.time()
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
async for event in graph.astream(
|
||||
{"messages": history, "config_id": config_id},
|
||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
messages = event.get('messages')
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
for msg in messages:
|
||||
msg_content = msg.content
|
||||
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
||||
outputs.append({
|
||||
"role": msg.__class__.__name__.lower().replace("message", ""),
|
||||
"role": msg_role,
|
||||
"content": msg_content
|
||||
})
|
||||
|
||||
# Extract intermediate outputs
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
# Debug: log message type and content preview
|
||||
msg_type = msg.__class__.__name__
|
||||
content_preview = str(msg_content)[:200] if msg_content else "empty"
|
||||
logger.debug(f"Processing message type={msg_type}, content preview={content_preview}")
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
content_to_parse = msg_content
|
||||
if isinstance(msg_content, list):
|
||||
for block in msg_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
content_to_parse = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
# Try to parse content as JSON
|
||||
if isinstance(msg_content, str):
|
||||
if isinstance(content_to_parse, str):
|
||||
try:
|
||||
parsed = json.loads(msg_content)
|
||||
parsed = json.loads(content_to_parse)
|
||||
if isinstance(parsed, dict):
|
||||
# Debug: log what keys are in parsed
|
||||
logger.debug(f"Parsed dict keys: {list(parsed.keys())}")
|
||||
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in parsed:
|
||||
intermediate_data = parsed['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
logger.debug(f"Found _intermediate: {intermediate_data.get('type', 'unknown')}")
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
@@ -450,34 +513,14 @@ class MemoryAgentService:
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in parsed:
|
||||
logger.debug(f"Found _intermediates list with {len(parsed['_intermediates'])} items")
|
||||
for intermediate_data in parsed['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
logger.debug(f"Processing intermediate: {intermediate_data.get('type', 'unknown')}")
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
elif isinstance(msg_content, dict):
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in msg_content:
|
||||
intermediate_data = msg_content['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in msg_content:
|
||||
for intermediate_data in msg_content['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||
|
||||
@@ -489,18 +532,57 @@ class MemoryAgentService:
|
||||
for messages in outputs:
|
||||
if messages['role'] == 'tool':
|
||||
message = messages['content']
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(message, list):
|
||||
# Extract text from MCP content blocks
|
||||
for block in message:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
message = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
try:
|
||||
message = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(message, dict) and message.get('status') != '':
|
||||
summary_result = message.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
parsed = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(parsed, dict):
|
||||
if parsed.get('status') == 'success':
|
||||
summary_result = parsed.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# 记录成功的操作
|
||||
total_duration = time.time() - start_time
|
||||
if audit_logger:
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.warning(f"Read workflow completed with errors: {error_details}")
|
||||
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=total_duration,
|
||||
error=error_details,
|
||||
details={
|
||||
"search_switch": search_switch,
|
||||
"history_length": len(history),
|
||||
"intermediate_outputs_count": len(intermediate_outputs),
|
||||
"has_answer": bool(final_answer),
|
||||
"errors": workflow_errors
|
||||
}
|
||||
)
|
||||
|
||||
# Raise error if no answer was produced
|
||||
if not final_answer:
|
||||
raise ValueError(f"Read workflow failed: {error_details}")
|
||||
|
||||
if audit_logger and not workflow_errors:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
@@ -612,19 +694,29 @@ class MemoryAgentService:
|
||||
else:
|
||||
return output
|
||||
|
||||
async def classify_message_type(self, message: str) -> Dict:
|
||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
message: User message to classify
|
||||
config_id: Configuration ID to load LLM model from database
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Type classification result
|
||||
"""
|
||||
logger.info("Classifying message type")
|
||||
|
||||
status = await status_typle(message)
|
||||
# Load configuration to get LLM model ID
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
|
||||
status = await status_typle(message, memory_config.llm_model_id)
|
||||
logger.debug(f"Message type: {status}")
|
||||
return status
|
||||
|
||||
@@ -790,7 +882,9 @@ class MemoryAgentService:
|
||||
async def get_user_profile(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user_id: Optional[str] = None
|
||||
current_user_id: Optional[str] = None,
|
||||
llm_id: Optional[str] = None,
|
||||
db: Session = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -801,6 +895,8 @@ class MemoryAgentService:
|
||||
参数:
|
||||
- end_user_id: 用户ID(可选)
|
||||
- current_user_id: 当前登录用户的ID(保留参数)
|
||||
- llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成)
|
||||
- db: 数据库会话(可选)
|
||||
|
||||
返回格式:
|
||||
{
|
||||
@@ -817,7 +913,7 @@ class MemoryAgentService:
|
||||
|
||||
# 1. 根据 end_user_id 获取 end_user_name
|
||||
try:
|
||||
if end_user_id:
|
||||
if end_user_id and db:
|
||||
from app.repositories import end_user_repository
|
||||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||||
|
||||
@@ -862,15 +958,19 @@ class MemoryAgentService:
|
||||
|
||||
await connector.close()
|
||||
|
||||
if not statements:
|
||||
if not statements or not llm_id:
|
||||
result["tags"] = []
|
||||
if not llm_id and statements:
|
||||
logger.warning("llm_id not provided, skipping tag generation")
|
||||
else:
|
||||
# 构建摘要文本
|
||||
summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}"
|
||||
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
|
||||
|
||||
# 使用LLM提取标签
|
||||
llm_client = get_llm_client()
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(llm_id)
|
||||
|
||||
# 定义标签提取的结构
|
||||
class UserTags(BaseModel):
|
||||
@@ -1032,4 +1132,69 @@ class MemoryAgentService:
|
||||
# "msg": "解析失败",
|
||||
# "error_code": "DOC_PARSE_ERROR",
|
||||
# "data": {"error": str(e)}
|
||||
# }
|
||||
# }
|
||||
|
||||
|
||||
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
|
||||
通过以下流程获取配置:
|
||||
1. 根据 end_user_id 获取用户的 app_id
|
||||
2. 获取该应用的最新发布版本
|
||||
3. 从发布版本的 config 字段中提取 memory_config_id
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
包含 memory_config_id 和相关信息的字典
|
||||
|
||||
Raises:
|
||||
ValueError: 当终端用户不存在或应用未发布时
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
# 1. 获取 end_user 及其 app_id
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if not end_user:
|
||||
logger.warning(f"End user not found: {end_user_id}")
|
||||
raise ValueError(f"终端用户不存在: {end_user_id}")
|
||||
|
||||
app_id = end_user.app_id
|
||||
logger.debug(f"Found end_user app_id: {app_id}")
|
||||
|
||||
# 2. 获取该应用的最新发布版本
|
||||
stmt = (
|
||||
select(AppRelease)
|
||||
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
|
||||
.order_by(AppRelease.version.desc())
|
||||
)
|
||||
latest_release = db.scalars(stmt).first()
|
||||
|
||||
if not latest_release:
|
||||
logger.warning(f"No active release found for app: {app_id}")
|
||||
raise ValueError(f"应用未发布: {app_id}")
|
||||
|
||||
logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")
|
||||
|
||||
# 3. 从 config 中提取 memory_config_id
|
||||
config = latest_release.config or {}
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
result = {
|
||||
"end_user_id": str(end_user_id),
|
||||
"app_id": str(app_id),
|
||||
"release_id": str(latest_release.id),
|
||||
"release_version": latest_release.version,
|
||||
"memory_config_id": memory_config_id
|
||||
}
|
||||
|
||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
|
||||
return result
|
||||
399
api/app/services/memory_config_service.py
Normal file
399
api/app/services/memory_config_service.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Memory Configuration Service
|
||||
|
||||
Centralized configuration loading and management for memory services.
|
||||
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.validators.memory_config_validators import (
|
||||
validate_and_resolve_model_id,
|
||||
validate_embedding_model,
|
||||
validate_model_exists_and_active,
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.schemas.memory_config_schema import (
|
||||
ConfigurationError,
|
||||
InvalidConfigError,
|
||||
MemoryConfig,
|
||||
ModelInactiveError,
|
||||
ModelNotFoundError,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
|
||||
def _validate_config_id(config_id):
|
||||
"""Validate configuration ID format."""
|
||||
if config_id is None:
|
||||
raise InvalidConfigError(
|
||||
"Configuration ID cannot be None",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
if isinstance(config_id, int):
|
||||
if config_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {config_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return config_id
|
||||
|
||||
if isinstance(config_id, str):
|
||||
try:
|
||||
parsed_id = int(config_id.strip())
|
||||
if parsed_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {parsed_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return parsed_id
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid configuration ID format: '{config_id}'",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
raise InvalidConfigError(
|
||||
f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
|
||||
class MemoryConfigService:
|
||||
"""
|
||||
Centralized service for memory configuration loading and validation.
|
||||
|
||||
This class provides a single implementation of configuration loading logic
|
||||
that can be shared across multiple services, eliminating code duplication.
|
||||
|
||||
Usage:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""Initialize the service with a database session.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: int,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
service_name: Name of the calling service (for logging purposes)
|
||||
|
||||
Returns:
|
||||
MemoryConfig: Immutable configuration object
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
config_logger.info(
|
||||
"Starting memory configuration loading",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": config_id,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Loading memory configuration from database: config_id={config_id}")
|
||||
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id)
|
||||
|
||||
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||
if not result:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
"Configuration not found in database",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"config_id": validated_config_id,
|
||||
"load_result": "not_found",
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"service": service_name,
|
||||
},
|
||||
)
|
||||
raise ConfigurationError(
|
||||
f"Configuration {validated_config_id} not found in database"
|
||||
)
|
||||
|
||||
memory_config, workspace = result
|
||||
|
||||
# Validate embedding model
|
||||
embedding_uuid = validate_embedding_model(
|
||||
validated_config_id,
|
||||
memory_config.embedding_id,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
workspace.id,
|
||||
)
|
||||
|
||||
# Resolve LLM model
|
||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||
memory_config.llm_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=True,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
# Resolve optional rerank model
|
||||
rerank_uuid = None
|
||||
rerank_name = None
|
||||
if memory_config.rerank_id:
|
||||
rerank_uuid, rerank_name = validate_and_resolve_model_id(
|
||||
memory_config.rerank_id,
|
||||
"rerank",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
# Get embedding model name
|
||||
embedding_name, _ = validate_model_exists_and_active(
|
||||
embedding_uuid,
|
||||
"embedding",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
# Create immutable MemoryConfig object
|
||||
config = MemoryConfig(
|
||||
config_id=memory_config.config_id,
|
||||
config_name=memory_config.config_name,
|
||||
workspace_id=workspace.id,
|
||||
workspace_name=workspace.name,
|
||||
tenant_id=workspace.tenant_id,
|
||||
llm_model_id=llm_uuid,
|
||||
llm_model_name=llm_name,
|
||||
embedding_model_id=embedding_uuid,
|
||||
embedding_model_name=embedding_name,
|
||||
rerank_model_id=rerank_uuid,
|
||||
rerank_model_name=rerank_name,
|
||||
storage_type=workspace.storage_type or "neo4j",
|
||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||
reflexion_iteration_period=int(memory_config.iteration_period or "3"),
|
||||
reflexion_range=memory_config.reflexion_range or "retrieval",
|
||||
reflexion_baseline=memory_config.baseline or "time",
|
||||
loaded_at=datetime.now(),
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
config_logger.info(
|
||||
"Memory configuration loaded successfully",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": validated_config_id,
|
||||
"config_name": config.config_name,
|
||||
"workspace_id": str(config.workspace_id),
|
||||
"load_result": "success",
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Memory configuration loaded successfully: {config.config_name}")
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
config_logger.error(
|
||||
"Failed to load memory configuration",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": config_id,
|
||||
"load_result": "error",
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.error(f"Failed to load memory configuration {config_id}: {e}")
|
||||
if isinstance(e, (ConfigurationError, ValueError)):
|
||||
raise
|
||||
else:
|
||||
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
|
||||
|
||||
def get_model_config(self, model_id: str) -> dict:
|
||||
"""Get LLM model configuration by ID.
|
||||
|
||||
Args:
|
||||
model_id: Model ID to look up
|
||||
|
||||
Returns:
|
||||
Dict with model configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.core.config import settings
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
|
||||
if not config:
|
||||
logger.warning(f"Model ID {model_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"type": config.type,
|
||||
"timeout": settings.LLM_TIMEOUT,
|
||||
"max_retries": settings.LLM_MAX_RETRIES,
|
||||
}
|
||||
|
||||
def get_embedder_config(self, embedding_id: str) -> dict:
|
||||
"""Get embedding model configuration by ID.
|
||||
|
||||
Args:
|
||||
embedding_id: Embedding model ID to look up
|
||||
|
||||
Returns:
|
||||
Dict with embedder configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
|
||||
if not config:
|
||||
logger.warning(f"Embedding model ID {embedding_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"type": config.type,
|
||||
"timeout": 120.0,
|
||||
"max_retries": 5,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_pipeline_config(memory_config: MemoryConfig):
|
||||
"""Build ExtractionPipelineConfig from MemoryConfig.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing all pipeline settings.
|
||||
|
||||
Returns:
|
||||
ExtractionPipelineConfig with deduplication, statement extraction,
|
||||
and forgetting engine settings.
|
||||
"""
|
||||
from app.core.memory.models.variate_config import (
|
||||
DedupConfig,
|
||||
ExtractionPipelineConfig,
|
||||
ForgettingEngineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
|
||||
dedup_config = DedupConfig(
|
||||
enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
|
||||
enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
|
||||
fuzzy_name_threshold_strict=memory_config.t_name_strict,
|
||||
fuzzy_type_threshold_strict=memory_config.t_type_strict,
|
||||
fuzzy_overall_threshold=memory_config.t_overall,
|
||||
)
|
||||
|
||||
stmt_config = StatementExtractionConfig(
|
||||
statement_granularity=memory_config.statement_granularity,
|
||||
include_dialogue_context=memory_config.include_dialogue_context,
|
||||
max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
|
||||
)
|
||||
|
||||
forget_config = ForgettingEngineConfig(
|
||||
offset=memory_config.offset,
|
||||
lambda_time=memory_config.lambda_time,
|
||||
lambda_mem=memory_config.lambda_mem,
|
||||
)
|
||||
|
||||
return ExtractionPipelineConfig(
|
||||
statement_extraction=stmt_config,
|
||||
deduplication=dedup_config,
|
||||
forgetting_engine=forget_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_pruning_config(memory_config: MemoryConfig) -> dict:
|
||||
"""Retrieve semantic pruning config from MemoryConfig.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing pruning settings.
|
||||
|
||||
Returns:
|
||||
Dict suitable for PruningConfig.model_validate with keys:
|
||||
- pruning_switch: bool
|
||||
- pruning_scene: str
|
||||
- pruning_threshold: float
|
||||
"""
|
||||
return {
|
||||
"pruning_switch": memory_config.pruning_enabled,
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
}
|
||||
@@ -4,38 +4,37 @@ Memory Storage Service
|
||||
Handles business logic for memory storage operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.models.user_model import User
|
||||
from app.core.logging_config import get_logger
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigPilotRun,
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
import uuid
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
# Load environment variables for Neo4j connector
|
||||
load_dotenv()
|
||||
@@ -247,7 +246,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
RuntimeError: 当管线执行失败时
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
|
||||
|
||||
try:
|
||||
# 发出初始进度事件
|
||||
@@ -256,24 +254,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
# 步骤 1: 配置加载和验证(复用现有逻辑)
|
||||
# 步骤 1: 配置加载和验证(数据库优先)
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
raise ValueError("未提供 payload.config_id,禁止启动试运行")
|
||||
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
@@ -281,12 +267,16 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
|
||||
# 应用内存覆写并刷新常量
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
# Load configuration from database only using centralized manager
|
||||
try:
|
||||
config_service = MemoryConfigService(self.db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=int(cid),
|
||||
service_name="MemoryStorageService.pilot_run_stream"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
raise RuntimeError(f"Configuration loading failed: {e}")
|
||||
|
||||
# 步骤 2: 创建进度回调函数捕获管线进度
|
||||
# 使用队列在回调和生成器之间传递进度事件
|
||||
@@ -307,13 +297,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
async def run_pipeline():
|
||||
"""在后台执行管线并捕获异常"""
|
||||
try:
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
from app.services.pilot_run_service import run_pilot_extraction
|
||||
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(
|
||||
dialogue_text=dialogue_text,
|
||||
is_pilot_run=True,
|
||||
progress_callback=progress_callback
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
|
||||
await run_pilot_extraction(
|
||||
memory_config=memory_config,
|
||||
dialogue_text=dialogue_text,
|
||||
db=self.db,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
|
||||
|
||||
|
||||
219
api/app/services/pilot_run_service.py
Normal file
219
api/app/services/pilot_run_service.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Pilot Run Service - 试运行服务
|
||||
|
||||
用于执行记忆系统的试运行流程,不保存到 Neo4j。
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger, log_time
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
get_chunked_dialogs_from_preprocessed,
|
||||
)
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_pipeline_config,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
async def run_pilot_extraction(
|
||||
memory_config: MemoryConfig,
|
||||
dialogue_text: str,
|
||||
db: Session,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
执行试运行模式的知识提取流水线。
|
||||
|
||||
Args:
|
||||
memory_config: 从数据库加载的内存配置对象
|
||||
dialogue_text: 输入的对话文本
|
||||
progress_callback: 可选的进度回调函数
|
||||
- 参数1 (stage): 当前处理阶段标识符
|
||||
- 参数2 (message): 人类可读的进度消息
|
||||
- 参数3 (data): 可选的附加数据字典
|
||||
"""
|
||||
log_file = "logs/time.log"
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
neo4j_connector = None
|
||||
|
||||
try:
|
||||
# 步骤 1: 初始化客户端
|
||||
logger.info("Initializing clients...")
|
||||
step_start = time.time()
|
||||
|
||||
client_factory = MemoryClientFactory(db)
|
||||
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
|
||||
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
log_time("Client Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 2: 解析对话文本
|
||||
logger.info("Parsing dialogue text...")
|
||||
step_start = time.time()
|
||||
|
||||
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
|
||||
pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)"
|
||||
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
|
||||
messages = [
|
||||
ConversationMessage(role=r, msg=c.strip())
|
||||
for r, c in matches
|
||||
if c.strip()
|
||||
]
|
||||
|
||||
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
|
||||
if not messages:
|
||||
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
|
||||
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context,
|
||||
ref_id="pilot_dialog_1",
|
||||
group_id=str(memory_config.workspace_id),
|
||||
user_id=str(memory_config.tenant_id),
|
||||
apply_id=str(memory_config.config_id),
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"},
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
chunker_strategy=memory_config.chunker_strategy,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed dialogue text: {len(messages)} messages")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dlg in chunked_dialogs:
|
||||
for i, chunk in enumerate(dlg.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dlg.id,
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 3: 初始化流水线编排器
|
||||
logger.info("Initializing extraction orchestrator...")
|
||||
step_start = time.time()
|
||||
|
||||
config = get_pipeline_config(memory_config)
|
||||
logger.info(
|
||||
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
|
||||
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
|
||||
)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 4: 执行知识提取流水线
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=True,
|
||||
)
|
||||
|
||||
# 解包 extraction_result tuple (与 main.py 保持一致)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("generating_results", "正在生成结果...")
|
||||
|
||||
# 步骤 5: 生成记忆摘要(与 main.py 保持一致)
|
||||
try:
|
||||
logger.info("Generating memory summaries...")
|
||||
step_start = time.time()
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs,
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
)
|
||||
|
||||
log_time("Memory Summary Generation", time.time() - step_start, log_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
|
||||
logger.info("Pilot run completed: Skipping Neo4j save")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if neo4j_connector:
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")
|
||||
@@ -4,15 +4,15 @@ User Memory Service
|
||||
处理用户记忆相关的业务逻辑,包括记忆洞察、用户摘要、节点统计和图数据等。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -284,8 +284,7 @@ class UserMemoryService:
|
||||
# 使用 end_user_id 调用分析函数
|
||||
try:
|
||||
logger.info(f"使用 end_user_id={end_user_id} 生成用户摘要")
|
||||
result = await analytics_user_summary(end_user_id)
|
||||
summary = result.get("summary", "")
|
||||
summary = await generate_user_summary(end_user_id)
|
||||
|
||||
if not summary:
|
||||
logger.warning(f"end_user_id {end_user_id} 的用户摘要生成结果为空")
|
||||
|
||||
142
api/app/tasks.py
142
api/app/tasks.py
@@ -1,26 +1,30 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
import requests
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from math import ceil
|
||||
import redis
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.db import get_db_context
|
||||
from app.models.document_model import Document
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.core.config import settings
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
import redis
|
||||
import requests
|
||||
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models.document_model import Document
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.process_item")
|
||||
@@ -170,7 +174,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
"""Celery task to process a read message via MemoryAgentService.
|
||||
|
||||
Args:
|
||||
group_id: Group ID for the memory agent
|
||||
group_id: Group ID for the memory agent (also used as end_user_id)
|
||||
message: User message to process
|
||||
history: Conversation history
|
||||
search_switch: Search switch parameter
|
||||
@@ -184,9 +188,28 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Resolve config_id if None
|
||||
actual_config_id = config_id
|
||||
if actual_config_id is None:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
service = MemoryAgentService()
|
||||
return await service.read_memory(group_id, message, history, search_switch, config_id,storage_type,user_rag_memory_id)
|
||||
db = next(get_db())
|
||||
try:
|
||||
service = MemoryAgentService()
|
||||
return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
@@ -217,11 +240,17 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
except Exception as e:
|
||||
except BaseException as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
# Handle ExceptionGroup from TaskGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
else:
|
||||
detailed_error = str(e)
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"error": detailed_error,
|
||||
"group_id": group_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -234,7 +263,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
|
||||
Args:
|
||||
group_id: Group ID for the memory agent
|
||||
group_id: Group ID for the memory agent (also used as end_user_id)
|
||||
message: Message to write
|
||||
config_id: Optional configuration ID
|
||||
|
||||
@@ -246,9 +275,28 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Resolve config_id if None
|
||||
actual_config_id = config_id
|
||||
if actual_config_id is None:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
service = MemoryAgentService()
|
||||
return await service.write_memory(group_id, message, config_id,storage_type,user_rag_memory_id)
|
||||
db = next(get_db())
|
||||
try:
|
||||
service = MemoryAgentService()
|
||||
return await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
@@ -279,11 +327,17 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
except Exception as e:
|
||||
except BaseException as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
# Handle ExceptionGroup from TaskGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
else:
|
||||
detailed_error = str(e)
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"error": detailed_error,
|
||||
"group_id": group_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -291,6 +345,27 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
||||
}
|
||||
|
||||
|
||||
def reflection_engine() -> None:
|
||||
"""Empty function placeholder for timed background reflection.
|
||||
|
||||
Intentionally left blank; replace with real reflection logic later.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
|
||||
|
||||
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
|
||||
asyncio.run(self_reflexion(host_id))
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.reflection.timer")
|
||||
def reflection_timer_task() -> None:
|
||||
"""Periodic Celery task that invokes reflection_engine.
|
||||
|
||||
Raises an exception on failure.
|
||||
"""
|
||||
reflection_engine()
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
||||
def check_read_service_task() -> Dict[str, str]:
|
||||
@@ -353,10 +428,10 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.services.memory_storage_service import search_all
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.services.memory_storage_service import search_all
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
@@ -465,9 +540,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.services.user_memory_service import UserMemoryService
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.services.user_memory_service import UserMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行记忆缓存重新生成定时任务")
|
||||
@@ -645,9 +720,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.services.memory_reflection_service import (
|
||||
MemoryReflectionService,
|
||||
WorkspaceAppService,
|
||||
)
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
|
||||
@@ -127,6 +127,7 @@ dependencies = [
|
||||
"uvicorn>=0.34.0",
|
||||
"celery>=5.5.2",
|
||||
"simpleeval>=1.0.3",
|
||||
"langchain-aws>=1.0.0a1",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
71
api/uv.lock
generated
71
api/uv.lock
generated
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = "==3.12.*"
|
||||
resolution-markers = [
|
||||
"sys_platform == 'darwin'",
|
||||
@@ -7,6 +7,9 @@ resolution-markers = [
|
||||
"(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
]
|
||||
|
||||
[options]
|
||||
prerelease-mode = "allow"
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
version = "2.6.1"
|
||||
@@ -241,6 +244,34 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.42.14"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "botocore" },
|
||||
{ name = "jmespath" },
|
||||
{ name = "s3transfer" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/09/72/e236ca627bc0461710685f5b7438f759ef3b4106e0e08dda08513a6539ab/boto3-1.42.14.tar.gz", hash = "sha256:a5d005667b480c844ed3f814a59f199ce249d0f5669532a17d06200c0a93119c", size = 112825, upload-time = "2025-12-19T20:27:15.325Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/ba/c657ea6f6d63563cc46748202fccd097b51755d17add00ebe4ea27580d06/boto3-1.42.14-py3-none-any.whl", hash = "sha256:bfcc665227bb4432a235cb4adb47719438d6472e5ccbf7f09512046c3f749670", size = 140571, upload-time = "2025-12-19T20:27:13.316Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.42.14"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jmespath" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/35/3f/50c56f093c2c6ce6de1f579726598db1cf9a9cccd3bf8693f73b1cf5e319/botocore-1.42.14.tar.gz", hash = "sha256:cf5bebb580803c6cfd9886902ca24834b42ecaa808da14fb8cd35ad523c9f621", size = 14910547, upload-time = "2025-12-19T20:27:04.431Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/94/67a78a8d08359e779894d4b1672658a3c7fcce216b48f06dfbe1de45521d/botocore-1.42.14-py3-none-any.whl", hash = "sha256:efe89adfafa00101390ec2c371d453b3359d5f9690261bc3bd70131e0d453e8e", size = 14583247, upload-time = "2025-12-19T20:27:00.54Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cachetools"
|
||||
version = "6.2.1"
|
||||
@@ -1114,6 +1145,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2f/9c/6753e6522b8d0ef07d3a3d239426669e984fb0eba15a315cdbc1253904e4/jiter-0.12.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c24e864cb30ab82311c6425655b0cdab0a98c5d973b065c66a3f020740c2324c", size = 346110, upload-time = "2025-11-09T20:49:21.817Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jmespath"
|
||||
version = "1.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/00/2a/e867e8531cf3e36b41201936b7fa7ba7b5702dbef42922193f05c8976cd6/jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe", size = 25843, upload-time = "2022-06-17T18:00:12.224Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "joblib"
|
||||
version = "1.5.2"
|
||||
@@ -1262,6 +1302,21 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/39/ed3121ea3a0c60a0cda6ea5c4c1cece013e8bbc9b18344ff3ae507728f98/langchain-1.1.3-py3-none-any.whl", hash = "sha256:e5b208ed93e553df4087117a40bd0d450f9095030a843cad35c53ff2814bf731", size = 102227, upload-time = "2025-12-08T19:31:47.246Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-aws"
|
||||
version = "1.0.0a1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "boto3" },
|
||||
{ name = "langchain-core" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pydantic" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c0/c3/a98c0849c13c6880b5629409cadb22d4070e9c611013da127be975f8c0dc/langchain_aws-1.0.0a1.tar.gz", hash = "sha256:3bb193a5fa915520c52bb47581e892d11ac4d114939a1b3ecfeca56fe153fff7", size = 121650, upload-time = "2025-09-18T20:52:36.098Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/7b/be49a224fe3aa07ed869801356f06e1d7a321bb7f22b6f7935dce86d258a/langchain_aws-1.0.0a1-py3-none-any.whl", hash = "sha256:24207d05c619ea61dfeab0a0f7086ae388cc3f2f5c03a8ae56b12d1b77d72585", size = 146839, upload-time = "2025-09-18T20:52:35.013Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-classic"
|
||||
version = "1.0.0"
|
||||
@@ -2825,6 +2880,7 @@ dependencies = [
|
||||
{ name = "json-repair" },
|
||||
{ name = "kombu" },
|
||||
{ name = "langchain" },
|
||||
{ name = "langchain-aws" },
|
||||
{ name = "langchain-community" },
|
||||
{ name = "langchain-mcp-adapters" },
|
||||
{ name = "langchain-ollama" },
|
||||
@@ -2949,6 +3005,7 @@ requires-dist = [
|
||||
{ name = "json-repair", specifier = "==0.53.0" },
|
||||
{ name = "kombu", specifier = "==5.5.4" },
|
||||
{ name = "langchain", specifier = ">=1.0.3" },
|
||||
{ name = "langchain-aws", specifier = ">=1.0.0a1" },
|
||||
{ name = "langchain-community", specifier = ">=0.3.31" },
|
||||
{ name = "langchain-mcp-adapters", specifier = ">=0.1.13" },
|
||||
{ name = "langchain-ollama" },
|
||||
@@ -3199,6 +3256,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/e5/8925a4208f131b218f9a7e459c0d6fcac8324ae35da269cb437894576366/ruamel_yaml_clib-0.2.15-cp312-cp312-win_amd64.whl", hash = "sha256:2b216904750889133d9222b7b873c199d48ecbb12912aca78970f84a5aa1a4bc", size = 119013, upload-time = "2025-11-16T16:13:32.164Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "s3transfer"
|
||||
version = "0.16.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "botocore" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scikit-learn"
|
||||
version = "1.7.2"
|
||||
|
||||
Reference in New Issue
Block a user