refactor(memory): restructure memory system and improve configuration management

- Remove deprecated main.py entry point from memory module
- Reorganize imports across controllers and services for consistency
- Update emotion controller to pass db session instead of config_id to services
- Enhance memory agent controller with db session parameter for status_type and user_profile endpoints
- Refactor memory agent service to accept db parameter in classify_message_type method
- Improve configuration handling in celery_app by removing automatic database reload
- Update all memory-related services to use centralized config management
- Standardize import ordering and remove unused imports across 50+ files
- Add pilot_run_service for new pilot execution workflow
- Refactor extraction engine, reflection engine, and search services for better modularity
- Update LLM utilities and embedder configuration for improved flexibility
- Enhance type classifier and verification tools with better error handling
- Improve memory evaluation modules (LOCOMO, LongMemEval, MemSciQA) with consistent patterns
This commit is contained in:
Ke Sun
2025-12-23 17:17:04 +08:00
parent 258b88276f
commit 283c64a358
58 changed files with 2171 additions and 1797 deletions

View File

@@ -1,9 +1,9 @@
import os import os
from datetime import timedelta from datetime import timedelta
from urllib.parse import quote from urllib.parse import quote
from celery import Celery
from app.core.config import settings from app.core.config import settings
from app.core.memory.utils.config.definitions import reload_configuration_from_database from celery import Celery
# 创建 Celery 应用实例 # 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0 # broker: 任务队列(使用 Redis DB 0
@@ -13,7 +13,6 @@ celery_app = Celery(
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}", broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
) )
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
# 配置使用本地队列,避免与远程 worker 冲突 # 配置使用本地队列,避免与远程 worker 冲突
celery_app.conf.task_default_queue = 'localhost_test_wyl' celery_app.conf.task_default_queue = 'localhost_test_wyl'
@@ -22,6 +21,7 @@ celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# macOS 兼容性配置 # macOS 兼容性配置
import platform import platform
if platform.system() == 'Darwin': # macOS if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题 # 设置环境变量解决 fork 问题
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')

View File

@@ -10,22 +10,21 @@ Routes:
POST /emotion/suggestions - 获取个性化情绪建议 POST /emotion/suggestions - 获取个性化情绪建议
""" """
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.dependencies import get_current_user, get_db from app.dependencies import get_current_user, get_db
from app.models.user_model import User from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.schemas.emotion_schema import ( from app.schemas.emotion_schema import (
EmotionHealthRequest,
EmotionSuggestionsRequest,
EmotionTagsRequest, EmotionTagsRequest,
EmotionWordcloudRequest, EmotionWordcloudRequest,
EmotionHealthRequest,
EmotionSuggestionsRequest
) )
from app.schemas.response_schema import ApiResponse
from app.services.emotion_analytics_service import EmotionAnalyticsService from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.core.logging_config import get_api_logger from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
# 获取API专用日志器 # 获取API专用日志器
api_logger = get_api_logger() api_logger = get_api_logger()
@@ -230,7 +229,7 @@ async def get_emotion_suggestions(
# 调用服务层 # 调用服务层
data = await emotion_service.generate_emotion_suggestions( data = await emotion_service.generate_emotion_suggestions(
end_user_id=request.group_id, end_user_id=request.group_id,
config_id=config_id db=db
) )
api_logger.info( api_logger.info(

View File

@@ -164,6 +164,7 @@ async def write_server(
user_input.group_id, user_input.group_id,
user_input.message, user_input.message,
config_id, config_id,
db,
storage_type, storage_type,
user_rag_memory_id user_rag_memory_id
) )
@@ -280,6 +281,7 @@ async def read_server(
user_input.history, user_input.history,
user_input.search_switch, user_input.search_switch,
config_id, config_id,
db,
storage_type, storage_type,
user_rag_memory_id user_rag_memory_id
) )
@@ -548,6 +550,7 @@ async def get_write_task_result(
@router.post("/status_type", response_model=ApiResponse) @router.post("/status_type", response_model=ApiResponse)
async def status_type( async def status_type(
user_input: Write_UserInput, user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
@@ -561,7 +564,11 @@ async def status_type(
""" """
api_logger.info(f"Status type check requested for group {user_input.group_id}") api_logger.info(f"Status type check requested for group {user_input.group_id}")
try: try:
result = await memory_agent_service.classify_message_type(user_input.message) result = await memory_agent_service.classify_message_type(
user_input.message,
user_input.config_id,
db
)
return success(data=result) return success(data=result)
except Exception as e: except Exception as e:
api_logger.error(f"Message type classification failed: {str(e)}") api_logger.error(f"Message type classification failed: {str(e)}")
@@ -636,6 +643,7 @@ async def get_hot_memory_tags_by_user_api(
@router.get("/analytics/user_profile", response_model=ApiResponse) @router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api( async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"), end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
@@ -659,7 +667,8 @@ async def get_user_profile_api(
try: try:
result = await memory_agent_service.get_user_profile( result = await memory_agent_service.get_user_profile(
end_user_id=end_user_id, end_user_id=end_user_id,
current_user_id=str(current_user.id) current_user_id=str(current_user.id),
db=db
) )
return success(data=result, msg="获取用户详情成功") return success(data=result, msg="获取用户详情成功")
except Exception as e: except Exception as e:
@@ -695,3 +704,40 @@ async def get_user_profile_api(
# except Exception as e: # except Exception as e:
# api_logger.error(f"API docs retrieval failed: {str(e)}") # api_logger.error(f"API docs retrieval failed: {str(e)}")
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e)) # return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
async def get_end_user_connected_config(
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取终端用户关联的记忆配置
通过以下流程获取配置:
1. 根据 end_user_id 获取用户的 app_id
2. 获取该应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id
Args:
end_user_id: 终端用户ID
Returns:
包含 memory_config_id 和相关信息的响应
"""
from app.services.memory_agent_service import (
get_end_user_connected_config as get_config,
)
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
try:
result = get_config(end_user_id, db)
return success(data=result, msg="获取终端用户关联配置成功")
except ValueError as e:
api_logger.warning(f"End user config not found: {str(e)}")
return fail(BizCode.NOT_FOUND, str(e))
except Exception as e:
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))

View File

@@ -1,22 +1,27 @@
import asyncio import asyncio
import time import time
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from sqlalchemy import text
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
ReflectionConfig,
ReflectionEngine,
)
from app.core.response_utils import success from app.core.response_utils import success
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine
from app.dependencies import get_current_user
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User from app.models.user_model import User
from app.repositories.data_config_repository import DataConfigRepository from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
from app.schemas.memory_reflection_schemas import Memory_Reflection from app.schemas.memory_reflection_schemas import Memory_Reflection
from app.services.memory_reflection_service import (
MemoryReflectionService,
WorkspaceAppService,
)
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import text
from sqlalchemy.orm import Session
load_dotenv() load_dotenv()
api_logger = get_api_logger() api_logger = get_api_logger()

View File

@@ -9,18 +9,19 @@ LangChain Agent 封装
""" """
import os import os
import time import time
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import BaseTool
from langchain.agents import create_agent
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType from app.models.models_model import ModelType
from app.core.logging_config import get_business_logger
from app.services.memory_konwledges_server import write_rag from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task from app.tasks import write_message_task
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
logger = get_business_logger() logger = get_business_logger()
@@ -198,10 +199,24 @@ class LangChainAgent:
""" """
message_chat= message message_chat= message
start_time = time.time() start_time = time.time()
if config_id == None: actual_config_id = config_id
actual_config_id = os.getenv("config_id") # If config_id is None, try to get from end_user's connected config
else: if actual_config_id is None and end_user_id:
actual_config_id = config_id try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown" actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
@@ -295,10 +310,24 @@ class LangChainAgent:
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}") logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80) logger.info("=" * 80)
message_chat = message message_chat = message
if config_id == None: actual_config_id = config_id
actual_config_id = os.getenv("config_id") # If config_id is None, try to get from end_user's connected config
else: if actual_config_id is None and end_user_id:
actual_config_id = config_id try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
history_term_memory = await self.term_memory_redis_read(end_user_id) history_term_memory = await self.term_memory_redis_read(end_user_id)
if memory_flag: if memory_flag:

View File

@@ -1,7 +1,8 @@
import os
import json import json
import os
from pathlib import Path from pathlib import Path
from typing import Dict, Any, Optional from typing import Any, Dict, Optional
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@@ -81,6 +82,7 @@ class Settings:
VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query") VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query")
# Langfuse configuration # Langfuse configuration
LANGFUSE_ENABLED: bool = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "") LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "") LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "") LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
@@ -153,9 +155,6 @@ class Settings:
# Memory Module Configuration (internal) # Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory") MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
MEMORY_CONFIG_FILE: str = os.getenv("MEMORY_CONFIG_FILE", "config.json")
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
# Tool Management Configuration # Tool Management Configuration
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools") TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
@@ -178,65 +177,6 @@ class Settings:
return str(base_path / filename) return str(base_path / filename)
return str(base_path) return str(base_path)
def get_memory_config_path(self, config_file: str = "") -> str:
"""
Get the full path for memory module configuration files.
Args:
config_file: Optional config filename (defaults to MEMORY_CONFIG_FILE)
Returns:
Full path to the config file
"""
if not config_file:
config_file = self.MEMORY_CONFIG_FILE
return str(Path(self.MEMORY_CONFIG_DIR) / config_file)
def load_memory_config(self) -> Dict[str, Any]:
"""
Load memory module configuration from config.json.
Returns:
Dictionary containing memory configuration
"""
config_path = self.get_memory_config_path(self.MEMORY_CONFIG_FILE)
try:
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory config file not found or malformed at {config_path}. Error: {e}")
return {}
def load_memory_runtime_config(self) -> Dict[str, Any]:
"""
Load memory module runtime configuration from runtime.json.
Returns:
Dictionary containing runtime configuration
"""
runtime_path = self.get_memory_config_path(self.MEMORY_RUNTIME_FILE)
try:
with open(runtime_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory runtime config not found or malformed at {runtime_path}. Error: {e}")
return {"selections": {}}
def load_memory_dbrun_config(self) -> Dict[str, Any]:
"""
Load memory module database run configuration from dbrun.json.
Returns:
Dictionary containing dbrun configuration
"""
dbrun_path = self.get_memory_config_path(self.MEMORY_DBRUN_FILE)
try:
with open(dbrun_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory dbrun config not found or malformed at {dbrun_path}. Error: {e}")
return {"selections": {}}
def ensure_memory_output_dir(self) -> None: def ensure_memory_output_dir(self) -> None:
""" """
Ensure the memory output directory exists. Ensure the memory output directory exists.

View File

@@ -141,7 +141,7 @@ class SearchService:
cleaned_query = self.clean_query(question) cleaned_query = self.clean_query(question)
try: try:
# Execute search using embedding_model_id from memory_config # Execute search using memory_config
answer = await run_hybrid_search( answer = await run_hybrid_search(
query_text=cleaned_query, query_text=cleaned_query,
search_type=search_type, search_type=search_type,
@@ -149,7 +149,7 @@ class SearchService:
limit=limit, limit=limit,
include=include, include=include,
output_path=output_path, output_path=output_path,
embedding_id=str(config.embedding_model_id), memory_config=config,
rerank_alpha=rerank_alpha, rerank_alpha=rerank_alpha,
) )

View File

@@ -13,7 +13,8 @@ from app.core.memory.agent.mcp_server.models.retrieval_models import (
) )
from app.core.memory.agent.mcp_server.server import get_context_resource from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.write_tools import write from app.core.memory.agent.utils.write_tools import write
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context from mcp.server.fastmcp import Context
@@ -41,8 +42,10 @@ async def Data_type_differentiation(
# Extract services from context # Extract services from context
template_service = get_context_resource(ctx, 'template_service') template_service = get_context_resource(ctx, 'template_service')
# Get LLM client from memory_config # Get LLM client from memory_config using factory pattern
llm_client = get_llm_client_from_config(memory_config) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Render template # Render template
try: try:

View File

@@ -16,7 +16,8 @@ from app.core.memory.agent.mcp_server.models.problem_models import (
) )
from app.core.memory.agent.mcp_server.server import get_context_resource from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context from mcp.server.fastmcp import Context
@@ -56,7 +57,9 @@ async def Split_The_Problem(
session_service = get_context_resource(ctx, "session_service") session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config # Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Extract user ID from session # Extract user ID from session
user_id = session_service.resolve_user_id(sessionid) user_id = session_service.resolve_user_id(sessionid)
@@ -190,7 +193,9 @@ async def Problem_Extension(
session_service = get_context_resource(ctx, "session_service") session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config # Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID from usermessages # Resolve session ID from usermessages
from app.core.memory.agent.utils.messages_tool import Resolve_username from app.core.memory.agent.utils.messages_tool import Resolve_username

View File

@@ -21,8 +21,9 @@ from app.core.memory.agent.utils.messages_tool import (
Resolve_username, Resolve_username,
Summary_messages_deal, Summary_messages_deal,
) )
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv from dotenv import load_dotenv
from mcp.server.fastmcp import Context from mcp.server.fastmcp import Context
@@ -66,7 +67,9 @@ async def Summary(
session_service = get_context_resource(ctx, "session_service") session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config # Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID # Resolve session ID
sessionid = Resolve_username(usermessages) sessionid = Resolve_username(usermessages)
@@ -210,7 +213,9 @@ async def Retrieve_Summary(
session_service = get_context_resource(ctx, "session_service") session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config # Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID # Resolve session ID
sessionid = Resolve_username(usermessages) sessionid = Resolve_username(usermessages)
@@ -425,7 +430,9 @@ async def Input_Summary(
search_service = get_context_resource(ctx, "search_service") search_service = get_context_resource(ctx, "search_service")
# Get LLM client from memory_config # Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID # Resolve session ID
sessionid = Resolve_username(usermessages) or "" sessionid = Resolve_username(usermessages) or ""

View File

@@ -1,15 +1,14 @@
""" """
Type classification utility for distinguishing read/write operations. Type classification utility for distinguishing read/write operations.
""" """
from jinja2 import Template from app.core.config import settings
from pydantic import BaseModel
from app.core.logging_config import get_agent_logger, log_prompt_rendering from app.core.logging_config import get_agent_logger, log_prompt_rendering
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import read_template_file from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.config import settings from app.db import get_db_context
from jinja2 import Template
from pydantic import BaseModel
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -44,7 +43,9 @@ async def status_typle(messages: str, llm_model_id: str) -> dict:
"message": f"Prompt rendering failed: {str(e)}" "message": f"Prompt rendering failed: {str(e)}"
} }
llm_client = get_llm_client(llm_model_id) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_model_id)
try: try:
structured = await llm_client.response_structured( structured = await llm_client.response_structured(

View File

@@ -1,18 +1,19 @@
from typing import TypedDict, Annotated, List, Any
from langchain_core.messages import AnyMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph, add_messages
import asyncio import asyncio
import json import json
from dotenv import load_dotenv, find_dotenv
import os import os
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ from typing import Annotated, Any, List, TypedDict
from langchain_core.messages import HumanMessage
from jinja2 import Environment, FileSystemLoader
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
from app.core.memory.utils.llm.llm_utils import get_llm_client
# Removed global variable imports - use dependency injection instead # Removed global variable imports - use dependency injection instead
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from dotenv import find_dotenv, load_dotenv
from jinja2 import Environment, FileSystemLoader
from langchain_core.messages import AnyMessage, HumanMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph, add_messages
load_dotenv(find_dotenv()) load_dotenv(find_dotenv())
@@ -53,7 +54,9 @@ class VerifyTool:
async def model_1(self, state: State) -> State: async def model_1(self, state: State) -> State:
if not self.llm_model_id: if not self.llm_model_id:
raise ValueError("llm_model_id is required but not provided") raise ValueError("llm_model_id is required but not provided")
llm_client = get_llm_client(self.llm_model_id) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(self.llm_model_id)
response_content = await llm_client.chat( response_content = await llm_client.chat(
messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])] messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])]
) )

View File

@@ -13,13 +13,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator
ExtractionOrchestrator, ExtractionOrchestrator,
) )
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
Memory_summary_generation, memory_summary_generation,
) )
from app.core.memory.utils.embedder.embedder_utils import ( from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
get_embedder_client_from_config,
)
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
from app.core.memory.utils.log.logging_utils import log_time from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
@@ -67,9 +65,11 @@ async def write(
logger.info(f"Chunker strategy: {chunker_strategy}") logger.info(f"Chunker strategy: {chunker_strategy}")
logger.info(f"Group ID: {group_id}") logger.info(f"Group ID: {group_id}")
# Construct clients from memory_config # Construct clients from memory_config using factory pattern with db session
llm_client = get_llm_client_from_config(memory_config) with get_db_context() as db:
embedder_client = get_embedder_client_from_config(memory_config) factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
embedder_client = factory.get_embedder_client_from_config(memory_config)
logger.info("LLM and embedding clients constructed") logger.info("LLM and embedding clients constructed")
# Initialize timing log # Initialize timing log
@@ -100,7 +100,7 @@ async def write(
# Step 2: Initialize and run ExtractionOrchestrator # Step 2: Initialize and run ExtractionOrchestrator
step_start = time.time() step_start = time.time()
from app.core.memory.utils.config.config_utils import get_pipeline_config from app.core.memory.utils.config.config_utils import get_pipeline_config
pipeline_config = get_pipeline_config() pipeline_config = get_pipeline_config(memory_config)
orchestrator = ExtractionOrchestrator( orchestrator = ExtractionOrchestrator(
llm_client=llm_client, llm_client=llm_client,
@@ -155,8 +155,8 @@ async def write(
# Step 4: Generate Memory summaries and save to Neo4j # Step 4: Generate Memory summaries and save to Neo4j
step_start = time.time() step_start = time.time()
try: try:
summaries = await Memory_summary_generation( summaries = await memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client
) )
try: try:

View File

@@ -35,7 +35,9 @@ except NameError:
import json import json
from app.core.config import settings from app.core.config import settings
from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
#TODO: Fix this #TODO: Fix this
# Default values (previously from definitions.py) # Default values (previously from definitions.py)
@@ -47,11 +49,37 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。""" """用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
async def filter_tags_with_llm(tags: List[str], llm_client) -> List[str]: async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
""" """
使用LLM筛选标签列表仅保留具有代表性的核心名词。 使用LLM筛选标签列表仅保留具有代表性的核心名词。
""" """
try: try:
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
# 3. 构建Prompt # 3. 构建Prompt
tag_list_str = ", ".join(tags) tag_list_str = ", ".join(tags)
@@ -156,8 +184,7 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
raw_tag_names = [tag for tag, freq in raw_tags_with_freq] raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
llm_client = get_llm_client(DEFAULT_LLM_ID) meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, llm_client)
# 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序 # 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序
final_tags = [] final_tags = []

View File

@@ -18,8 +18,10 @@ if src_path not in sys.path:
sys.path.insert(0, src_path) sys.path.insert(0, src_path)
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
#TODO: Fix this #TODO: Fix this
@@ -59,7 +61,33 @@ class MemoryInsight:
def __init__(self, user_id: str): def __init__(self, user_id: str):
self.user_id = user_id self.user_id = user_id
self.neo4j_connector = Neo4jConnector() self.neo4j_connector = Neo4jConnector()
self.llm_client = get_llm_client(DEFAULT_LLM_ID)
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self): async def close(self):
"""关闭数据库连接。""" """关闭数据库连接。"""

View File

@@ -25,8 +25,10 @@ except Exception:
pass pass
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
#TODO: Fix this #TODO: Fix this
@@ -47,7 +49,33 @@ class UserSummary:
def __init__(self, user_id: str): def __init__(self, user_id: str):
self.user_id = user_id self.user_id = user_id
self.connector = Neo4jConnector() self.connector = Neo4jConnector()
self.llm = get_llm_client(DEFAULT_LLM_ID)
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self): async def close(self):
await self.connector.close() await self.connector.close()

View File

@@ -1,22 +1,34 @@
import os
import asyncio import asyncio
import json import json
from typing import List, Dict, Any, Optional import os
from datetime import datetime
import re import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.llm_tools.openai_client import LLMClient from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker from app.core.memory.models.message_models import (
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage ConversationContext,
from app.repositories.neo4j.neo4j_connector import Neo4jConnector ConversationMessage,
from app.core.memory.utils.llm.llm_utils import get_llm_client DialogData,
from app.core.memory.utils.config.definitions import SELECTED_CHUNKER_STRATEGY, SELECTED_EMBEDDING_ID )
# 使用新的模块化架构 # 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
from app.core.memory.utils.config.definitions import (
SELECTED_CHUNKER_STRATEGY,
SELECTED_EMBEDDING_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
# Import from database module # Import from database module
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# Cypher queries for evaluation # Cypher queries for evaluation
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py # Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
@@ -52,7 +64,9 @@ async def ingest_contexts_via_full_pipeline(
llm_available = True llm_available = True
try: try:
from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.config import definitions as config_defs
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
except Exception as e: except Exception as e:
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}") print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
llm_available = False llm_available = False
@@ -133,12 +147,13 @@ async def ingest_contexts_via_full_pipeline(
return False return False
# 初始化 embedder 客户端 # 初始化 embedder 客户端
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.services.memory_config_service import MemoryConfigService
try: try:
embedder_config_dict = get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID) with get_db_context() as db:
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict) embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config) embedder_client = OpenAIEmbedderClient(embedder_config)
except Exception as e: except Exception as e:
@@ -236,15 +251,15 @@ async def ingest_contexts_via_full_pipeline(
print("[Ingestion] Generating memory summaries...") print("[Ingestion] Generating memory summaries...")
try: try:
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
Memory_summary_generation, memory_summary_generation,
) )
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
summaries = await Memory_summary_generation( summaries = await memory_summary_generation(
chunked_dialogs=dialog_data_list, chunked_dialogs=dialog_data_list,
llm_client=llm_client, llm_client=llm_client,
embedding_id=embedding_name or SELECTED_EMBEDDING_ID embedder_client=embedder_client
) )
print(f"[Ingestion] Generated {len(summaries)} memory summaries") print(f"[Ingestion] Generated {len(summaries)} memory summaries")
except Exception as e: except Exception as e:

View File

@@ -15,7 +15,7 @@ import json
import os import os
import time import time
from datetime import datetime from datetime import datetime
from typing import List, Dict, Any, Optional from typing import Any, Dict, List, Optional
try: try:
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -23,37 +23,38 @@ except ImportError:
def load_dotenv(): def load_dotenv():
pass pass
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config_utils import get_embedder_config
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
SELECTED_EMBEDDING_ID
)
from app.core.memory.utils.llm_utils import get_llm_client
from app.core.memory.evaluation.common.metrics import ( from app.core.memory.evaluation.common.metrics import (
f1_score, avg_context_tokens,
bleu1, bleu1,
f1_score,
jaccard, jaccard,
latency_stats, latency_stats,
avg_context_tokens
) )
from app.core.memory.evaluation.locomo.locomo_metrics import ( from app.core.memory.evaluation.locomo.locomo_metrics import (
get_category_name,
locomo_f1_score, locomo_f1_score,
locomo_multi_f1, locomo_multi_f1,
get_category_name
) )
from app.core.memory.evaluation.locomo.locomo_utils import ( from app.core.memory.evaluation.locomo.locomo_utils import (
load_locomo_data,
extract_conversations, extract_conversations,
ingest_conversations_if_needed,
load_locomo_data,
resolve_temporal_references, resolve_temporal_references,
select_and_format_information,
retrieve_relevant_information, retrieve_relevant_information,
ingest_conversations_if_needed select_and_format_information,
) )
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
async def run_locomo_benchmark( async def run_locomo_benchmark(
@@ -160,10 +161,16 @@ async def run_locomo_benchmark(
# Step 3: Initialize clients # Step 3: Initialize clients
print("🔧 Initializing clients...") print("🔧 Initializing clients...")
connector = Neo4jConnector() connector = Neo4jConnector()
llm_client = get_llm_client(SELECTED_LLM_ID)
# Initialize LLM client with database context
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Initialize embedder # Initialize embedder
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient( embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict) model_config=RedBearModelConfig.model_validate(cfg_dict)
) )

View File

@@ -1,14 +1,16 @@
# file name: check_neo4j_connection_fixed.py # file name: check_neo4j_connection_fixed.py
import asyncio import asyncio
import os
import sys
import json import json
import time
import math import math
import os
import re import re
import sys
import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Dict, Any from typing import Any, Dict, List
from dotenv import load_dotenv from dotenv import load_dotenv
# 1 # 1
# 添加项目根目录到路径 # 添加项目根目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -34,7 +36,7 @@ def _loc_normalize(text: str) -> str:
# 尝试从 metrics.py 导入基础指标 # 尝试从 metrics.py 导入基础指标
try: try:
from common.metrics import f1_score, bleu1, jaccard from common.metrics import bleu1, f1_score, jaccard
print("✅ 从 metrics.py 导入基础指标成功") print("✅ 从 metrics.py 导入基础指标成功")
except ImportError as e: except ImportError as e:
print(f"❌ 从 metrics.py 导入失败: {e}") print(f"❌ 从 metrics.py 导入失败: {e}")
@@ -111,10 +113,14 @@ try:
# 尝试从不同位置导入 # 尝试从不同位置导入
try: try:
from locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times from locomo.qwen_search_eval import (
_resolve_relative_times,
loc_f1_score,
loc_multi_f1,
)
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功") print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError: except ImportError:
from qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功") print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError as e: except ImportError as e:
@@ -429,13 +435,17 @@ async def run_enhanced_evaluation():
return None return None
# 修正导入路径:使用 app.core.memory.src 前缀 # 修正导入路径:使用 app.core.memory.src 前缀
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.client_factory import MemoryClientFactory
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.llm.llm_utils import get_llm_client from app.db import get_db_context
from app.core.memory.utils.config.config_utils import get_embedder_config from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_EMBEDDING_ID from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 加载数据 # 加载数据
# 获取项目根目录 # 获取项目根目录
@@ -458,10 +468,14 @@ async def run_enhanced_evaluation():
# 初始化增强监控器 # 初始化增强监控器
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6) monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
llm = get_llm_client(SELECTED_LLM_ID) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder # 初始化embedder
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient( embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict) model_config=RedBearModelConfig.model_validate(cfg_dict)
) )

View File

@@ -2,10 +2,11 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import statistics
import time import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Dict, Any from typing import Any, Dict, List
import statistics
try: try:
from dotenv import load_dotenv from dotenv import load_dotenv
except Exception: except Exception:
@@ -13,16 +14,31 @@ except Exception:
return None return None
import re import re
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
bleu1,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.storage_services.search import run_hybrid_search from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID from app.core.memory.utils.config.definitions import (
from app.core.memory.utils.llm.llm_utils import get_llm_client PROJECT_ROOT,
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline SELECTED_EMBEDDING_ID,
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现) # 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
@@ -327,9 +343,13 @@ async def run_locomo_eval(
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True) await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
# 使用异步LLM客户端 # 使用异步LLM客户端
llm_client = get_llm_client(SELECTED_LLM_ID) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder用于直接调用 # 初始化embedder用于直接调用
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient( embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict) model_config=RedBearModelConfig.model_validate(cfg_dict)
) )

View File

@@ -2,11 +2,11 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import time
import re import re
import statistics import statistics
import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Dict, Any from typing import Any, Dict, List
try: try:
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -16,6 +16,7 @@ except Exception:
# 确保可以找到 src 及项目根路径 # 确保可以找到 src 及项目根路径
import sys import sys
_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR))) _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR)))
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src") _SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
@@ -25,19 +26,33 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
# 与现有评估脚本保持一致的导入方式 # 与现有评估脚本保持一致的导入方式
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
try: try:
# 优先从 extraction_utils1 导入 # 优先从 extraction_utils1 导入
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline # type: ignore from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline, # type: ignore
)
except Exception: except Exception:
ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查 ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.evaluation.common.metrics import (
from app.core.models.base import RedBearModelConfig avg_context_tokens,
from app.core.memory.utils.config.config_utils import get_embedder_config jaccard,
from app.core.memory.utils.llm.llm_utils import get_llm_client latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.services.memory_config_service import MemoryConfigService
try: try:
from app.core.memory.evaluation.common.metrics import exact_match from app.core.memory.evaluation.common.metrics import exact_match
except Exception: except Exception:
@@ -686,9 +701,13 @@ async def run_longmemeval_test(
) )
# 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端 # 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端
llm_client = get_llm_client(SELECTED_LLM_ID) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
connector = Neo4jConnector() connector = Neo4jConnector()
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient( embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict) model_config=RedBearModelConfig.model_validate(cfg_dict)
) )
@@ -748,10 +767,10 @@ async def run_longmemeval_test(
if stmt_text: if stmt_text:
contexts_all.append(stmt_text) contexts_all.append(stmt_text)
for sm in summaries: # for sm in summaries:
summary_text = str(sm.get("summary", "")).strip() # summary_text = str(sm.get("summary", "")).strip()
if summary_text: # if summary_text:
contexts_all.append(summary_text) # contexts_all.append(summary_text)
# 实体摘要最多3个 # 实体摘要最多3个
scored = [e for e in entities if e.get("score") is not None] scored = [e for e in entities if e.get("score") is not None]

View File

@@ -2,11 +2,11 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import time
import re import re
import statistics import statistics
import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Dict, Any from typing import Any, Dict, List
try: try:
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -15,15 +15,26 @@ except Exception:
return None return None
# 与现有评估脚本保持一致的导入方式 # 与现有评估脚本保持一致的导入方式
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.client_factory import MemoryClientFactory
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding from app.core.memory.evaluation.common.metrics import (
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient avg_context_tokens,
from app.core.models.base import RedBearModelConfig jaccard,
from app.core.memory.utils.config_utils import get_embedder_config latency_stats,
from app.core.memory.utils.llm_utils import get_llm_client )
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
try: try:
from app.core.memory.evaluation.common.metrics import exact_match from app.core.memory.evaluation.common.metrics import exact_match
except Exception: except Exception:
@@ -647,9 +658,13 @@ async def run_longmemeval_test(
items = qa_list[start_index:start_index + sample_size] items = qa_list[start_index:start_index + sample_size]
# 初始化组件 - 使用异步LLM客户端 # 初始化组件 - 使用异步LLM客户端
llm_client = get_llm_client(SELECTED_LLM_ID) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
connector = Neo4jConnector() connector = Neo4jConnector()
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient( embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict) model_config=RedBearModelConfig.model_validate(cfg_dict)
) )

View File

@@ -4,19 +4,35 @@ import json
import os import os
import time import time
from datetime import datetime from datetime import datetime
from typing import List, Dict, Any from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
try: try:
from dotenv import load_dotenv from dotenv import load_dotenv
except Exception: except Exception:
def load_dotenv(): def load_dotenv():
return None return None
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.storage_services.search import run_hybrid_search from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID from app.core.memory.utils.config.definitions import (
from app.core.memory.utils.llm.llm_utils import get_llm_client PROJECT_ROOT,
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline SELECTED_EMBEDDING_ID,
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
@@ -119,7 +135,7 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
return merged return merged
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid") -> Dict[str, Any]: async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
group_id = group_id or SELECTED_GROUP_ID group_id = group_id or SELECTED_GROUP_ID
# Load data # Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
@@ -134,7 +150,9 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
await ingest_contexts_via_full_pipeline(contexts, group_id) await ingest_contexts_via_full_pipeline(contexts, group_id)
# LLM client (使用异步调用) # LLM client (使用异步调用)
llm_client = get_llm_client(SELECTED_LLM_ID) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Evaluate each item # Evaluate each item
connector = Neo4jConnector() connector = Neo4jConnector()
@@ -159,6 +177,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
limit=search_limit, limit=search_limit,
include=["dialogues", "statements", "entities"], include=["dialogues", "statements", "entities"],
output_path=None, output_path=None,
memory_config=memory_config,
) )
except Exception: except Exception:
results = None results = None
@@ -242,7 +261,11 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip()) pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference # Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
correct_flags.append(exact_match(pred, reference)) correct_flags.append(exact_match(pred, reference))
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard from app.core.memory.evaluation.common.metrics import (
bleu1,
f1_score,
jaccard,
)
f1s.append(f1_score(str(pred), str(reference))) f1s.append(f1_score(str(pred), str(reference)))
b1s.append(bleu1(str(pred), str(reference))) b1s.append(bleu1(str(pred), str(reference)))
jss.append(jaccard(str(pred), str(reference))) jss.append(jaccard(str(pred), str(reference)))

View File

@@ -2,10 +2,10 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import re
import time import time
from datetime import datetime from datetime import datetime
from typing import List, Dict, Any from typing import Any, Dict, List
import re
try: try:
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -15,6 +15,7 @@ except Exception:
# 路径与模块导入保持与现有评估脚本一致 # 路径与模块导入保持与现有评估脚本一致
import sys import sys
_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR)) _PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src") _SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
@@ -23,17 +24,27 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
sys.path.insert(0, _p) sys.path.insert(0, _p)
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1 # 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.client_factory import MemoryClientFactory
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config_utils import get_embedder_config from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
try: try:
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
except Exception: except Exception:
# 兜底:简单实现(必要时) # 兜底:简单实现(必要时)
def f1_score(pred: str, ref: str) -> float: def f1_score(pred: str, ref: str) -> float:
@@ -226,13 +237,17 @@ async def run_memsciqa_test(
items = all_items[start_index:start_index + sample_size] items = all_items[start_index:start_index + sample_size]
# 初始化 LLM纯测试不进行摄入 # 初始化 LLM纯测试不进行摄入
llm = get_llm_client(SELECTED_LLM_ID) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化 Neo4j 连接与向量检索 Embedder对齐 locomo_test # 初始化 Neo4j 连接与向量检索 Embedder对齐 locomo_test
connector = Neo4jConnector() connector = Neo4jConnector()
embedder = None embedder = None
if search_type in ("embedding", "hybrid"): if search_type in ("embedding", "hybrid"):
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient( embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict) model_config=RedBearModelConfig.model_validate(cfg_dict)
) )

View File

@@ -5,18 +5,17 @@ OpenAI LLM 客户端实现
""" """
import asyncio import asyncio
from typing import List, Dict, Any
import json import json
import logging import logging
from typing import Any, Dict, List
from pydantic import BaseModel from app.core.config import settings
from langchain_core.prompts import ChatPromptTemplate from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
from langchain_core.output_parsers import PydanticOutputParser
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.core.models.llm import RedBearLLM from app.core.models.llm import RedBearLLM
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException from langchain_core.output_parsers import PydanticOutputParser
from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -43,7 +42,7 @@ class OpenAIClient(LLMClient):
# 初始化 Langfuse 回调处理器(如果启用) # 初始化 Langfuse 回调处理器(如果启用)
self.langfuse_handler = None self.langfuse_handler = None
if LANGFUSE_ENABLED: if settings.LANGFUSE_ENABLED:
try: try:
from langfuse.langchain import CallbackHandler from langfuse.langchain import CallbackHandler
self.langfuse_handler = CallbackHandler() self.langfuse_handler = CallbackHandler()

View File

@@ -1,430 +0,0 @@
"""
MemSci 记忆系统主入口 - 重构版本
该模块是重构后的记忆系统主入口,使用新的模块化架构。
旧版本入口app/core/memory/src/main.py已删除。
主要功能:
1. 协调整个知识提取流水线
2. 支持试运行模式和正常运行模式
3. 使用重构后的 storage_services 模块
4. 提供统一的配置管理和日志记录
作者Lance77
日期2025-11-22
"""
# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误
import os
os.environ["LANGCHAIN_TRACING_V2"] = "false"
os.environ["LANGCHAIN_TRACING"] = "false"
import asyncio
import time
from datetime import datetime
from typing import Optional, Callable, Awaitable
from dotenv import load_dotenv
# 导入重构后的模块
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
# 导入数据加载函数
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
get_chunked_dialogs_with_preprocessing,
get_chunked_dialogs_from_preprocessed,
)
# 导入配置模块(而不是直接导入变量)
from app.core.memory.utils.config import definitions as config_defs
from app.core.logging_config import get_memory_logger, log_time
load_dotenv()
logger = get_memory_logger(__name__)
async def main(
# Required configuration parameters (no longer from global variables)
chunker_strategy: str,
group_id: str,
user_id: str,
apply_id: str,
llm_model_id: str,
embedding_model_id: str,
# Optional parameters
dialogue_text: Optional[str] = None,
is_pilot_run: bool = False,
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
):
"""
记忆系统主流程 - 重构版本 (Updated to eliminate global variables)
该函数是重构后的主入口,使用新的模块化架构。
Global variables have been eliminated in favor of explicit parameters.
Args:
chunker_strategy: Chunking strategy to use (required)
group_id: Group ID for the operation (required)
user_id: User ID for the operation (required)
apply_id: Application ID for the operation (required)
llm_model_id: LLM model ID to use (required)
embedding_model_id: Embedding model ID to use (required)
dialogue_text: 输入的对话文本(可选,用于试运行模式)
is_pilot_run: 是否为试运行模式
- True: 试运行模式,不保存到 Neo4j
- False: 正常运行模式,保存到 Neo4j
progress_callback: 可选的进度回调函数
- 类型: Callable[[str, str, Optional[dict]], Awaitable[None]]
- 参数1 (stage): 当前处理阶段标识符
- 参数2 (message): 人类可读的进度消息
- 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果
- 在管线关键点调用以报告进度和结果数据
工作流程:
1. 初始化客户端和配置
2. 加载或准备数据
3. 执行知识提取流水线
4. 保存结果(正常模式)或输出结果(试运行模式)
"""
print("=" * 60)
print("MemSci 知识提取流水线 - 重构版本")
print("=" * 60)
print(f"运行模式: {'试运行不保存到Neo4j' if is_pilot_run else '正常运行保存到Neo4j'}")
print("Using chunker strategy:", chunker_strategy)
print("Using group ID:", group_id)
print("Using model ID:", llm_model_id)
print("Using embedding model ID:", embedding_model_id)
print("=" * 60)
# 初始化日志
log_file = "logs/time.log"
os.makedirs(os.path.dirname(log_file), exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n")
pipeline_start = time.time()
try:
# 步骤 1: 初始化客户端
logger.info("Initializing clients...")
step_start = time.time()
llm_client = get_llm_client(llm_model_id)
# 获取 embedder 配置并转换为 RedBearModelConfig 对象
from app.core.models.base import RedBearModelConfig
embedder_config_dict = get_embedder_config(embedding_model_id)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
neo4j_connector = Neo4jConnector()
log_time("Client Initialization", time.time() - step_start, log_file)
# 步骤 2: 加载或准备数据
logger.info("Loading data...")
logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}")
logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}")
logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}")
step_start = time.time()
if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip():
# 试运行模式:处理前端传入的对话文本
logger.info("[MAIN] ✓ Using frontend dialogue text (pilot run mode)")
import re
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
messages = [
ConversationMessage(role=r, msg=c.strip())
for r, c in matches if c.strip()
]
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
if not messages:
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
# 创建对话上下文和对话数据
context = ConversationContext(msgs=messages)
dialog = DialogData(
context=context,
ref_id="pilot_dialog_1",
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
metadata={"source": "pilot_run", "input_type": "frontend_text"}
)
# 进度回调:开始预处理文本
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
# 对前端传入的对话进行分块处理
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
chunker_strategy=chunker_strategy,
llm_client=llm_client,
)
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
# 进度回调:输出每个分块的结果
if progress_callback:
for dialog in chunked_dialogs:
for i, chunk in enumerate(dialog.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dialog.id,
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 进度回调:预处理文本完成
preprocessing_summary = {
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
else:
# 正常运行模式:从 testdata.json 文件加载
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
logger.info("Loading data from testdata.json...")
test_data_path = os.path.join(
os.path.dirname(__file__), "data", "testdata.json"
)
if not os.path.exists(test_data_path):
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
# 进度回调:开始预处理文本
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
chunker_strategy=chunker_strategy,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
indices=None,
input_data_path=test_data_path,
llm_client=llm_client,
skip_cleaning=True,
)
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
# 进度回调:输出每个分块的结果
if progress_callback:
for dialog in chunked_dialogs:
for i, chunk in enumerate(dialog.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dialog.id,
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 进度回调:预处理文本完成
preprocessing_summary = {
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# 步骤 3: 初始化流水线编排器
logger.info("Initializing extraction orchestrator...")
step_start = time.time()
# 从 runtime.json 加载配置(已经过数据库覆写)
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}")
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
progress_callback=progress_callback, # 传递进度回调
embedding_id=embedding_model_id, # 传递嵌入模型ID
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
# 步骤 4: 执行知识提取流水线
logger.info("Running extraction pipeline...")
step_start = time.time()
# 进度回调:正在知识抽取
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=is_pilot_run, # 传递试运行模式标志
)
# 解包 extraction_result tuple
# extraction_result 是一个包含 7 个元素的 tuple:
# (dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes,
# statement_chunk_edges, statement_entity_edges, entity_edges)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# 进度回调:生成结果
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# 步骤 5: 保存结果或输出结果
if is_pilot_run:
logger.info("Pilot run mode: Skipping Neo4j save")
print("\n试运行模式:跳过 Neo4j 保存,流水线处理完成。")
print("提取结果已生成,可在相关输出中查看。")
else:
logger.info("Normal mode: Saving to Neo4j...")
step_start = time.time()
# 创建索引和约束
try:
from app.repositories.neo4j.create_indexes import (
create_fulltext_indexes,
create_unique_constraints,
)
await create_fulltext_indexes()
await create_unique_constraints()
logger.info("Successfully created indexes and constraints")
except Exception as e:
logger.error(f"Error creating indexes/constraints: {e}")
# 保存数据到 Neo4j
try:
from app.repositories.neo4j.graph_saver import (
save_dialog_and_statements_to_neo4j,
)
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=dialogue_nodes,
chunk_nodes=chunk_nodes,
statement_nodes=statement_nodes,
entity_nodes=entity_nodes,
statement_chunk_edges=statement_chunk_edges,
statement_entity_edges=statement_entity_edges,
entity_edges=entity_edges,
connector=neo4j_connector,
)
if success:
logger.info("Successfully saved all data to Neo4j")
print("\n✓ 成功保存所有数据到 Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
print("\n⚠ 部分数据保存到 Neo4j 失败")
except Exception as e:
logger.error(f"Error saving to Neo4j: {e}", exc_info=True)
print(f"\n✗ 保存到 Neo4j 失败: {e}")
log_time("Neo4j Database Save", time.time() - step_start, log_file)
# 步骤 6: 生成记忆摘要(可选)
try:
logger.info("Generating memory summaries...")
step_start = time.time()
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
Memory_summary_generation,
)
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import (
add_memory_summary_statement_edges,
)
summaries = await Memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
)
if not is_pilot_run:
# 保存记忆摘要到 Neo4j
ms_connector = Neo4jConnector()
try:
await add_memory_summary_nodes(summaries, ms_connector)
await add_memory_summary_statement_edges(summaries, ms_connector)
finally:
await ms_connector.close()
log_time("Memory Summary Generation", time.time() - step_start, log_file)
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
except Exception as e:
logger.error(f"Pipeline execution failed: {e}", exc_info=True)
print(f"\n✗ 流水线执行失败: {e}")
raise
finally:
# 清理资源
try:
await neo4j_connector.close()
except Exception:
pass
# 记录总时间
total_time = time.time() - pipeline_start
log_time("TOTAL PIPELINE TIME", total_time, log_file)
# 添加完成标记
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")
logger.info(f"Timing details saved to: {log_file}")
print("\n" + "=" * 60)
print("✓ 流水线执行完成")
print(f"✓ 总耗时: {total_time:.2f}")
print(f"✓ 详细日志: {log_file}")
print("=" * 60)
if __name__ == "__main__":
print("⚠️ Warning: This script now requires explicit configuration parameters.")
print("Global variables have been removed. Please provide configuration parameters.")
print("Example usage:")
print(" asyncio.run(main(")
print(" chunker_strategy='RecursiveChunker',")
print(" group_id='your_group_id',")
print(" user_id='your_user_id',")
print(" apply_id='your_apply_id',")
print(" llm_model_id='your_llm_id',")
print(" embedding_model_id='your_embedding_id'")
print(" ))")
# This will fail because global variables are removed
raise RuntimeError("Global variables removed. Please provide explicit configuration parameters.")

View File

@@ -20,14 +20,14 @@ Classes:
MemorySummaryNode: Node representing a memory summary MemorySummaryNode: Node representing a memory summary
""" """
from uuid import uuid4 import re
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Field, field_validator from uuid import uuid4
import re
from app.core.memory.utils.data.ontology import TemporalInfo
from app.core.memory.utils.alias_utils import validate_aliases from app.core.memory.utils.alias_utils import validate_aliases
from app.core.memory.utils.data.ontology import TemporalInfo
from pydantic import BaseModel, Field, field_validator
def parse_historical_datetime(v): def parse_historical_datetime(v):
@@ -361,7 +361,7 @@ class ExtractedEntityNode(Node):
description="Entity aliases - alternative names for this entity" description="Entity aliases - alternative names for this entity"
) )
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
fact_summary: str = Field(..., description="Summary of the fact about this entity") fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")

View File

@@ -10,9 +10,10 @@ Classes:
""" """
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Field, ConfigDict
from uuid import uuid4 from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field
class Entity(BaseModel): class Entity(BaseModel):
"""Represents an extracted entity from dialogue. """Represents an extracted entity from dialogue.

View File

@@ -5,7 +5,10 @@ import math
import os import os
import time import time
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
from app.core.logging_config import get_memory_logger from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
@@ -14,16 +17,14 @@ from app.core.memory.models.variate_config import ForgettingEngineConfig
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ( from app.core.memory.storage_services.forgetting_engine.forgetting_engine import (
ForgettingEngine, ForgettingEngine,
) )
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.config.config_utils import ( from app.core.memory.utils.config.config_utils import (
get_embedder_config,
get_pipeline_config, get_pipeline_config,
) )
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
from app.core.memory.utils.data.text_utils import extract_plain_query from app.core.memory.utils.data.text_utils import extract_plain_query
from app.core.memory.utils.data.time_utils import normalize_date_safe from app.core.memory.utils.data.time_utils import normalize_date_safe
from app.core.memory.utils.llm.llm_utils import get_reranker_client from app.core.memory.utils.llm.llm_utils import get_reranker_client
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import ( from app.repositories.neo4j.graph_search import (
search_graph, search_graph,
search_graph_by_chunk_id, search_graph_by_chunk_id,
@@ -34,6 +35,7 @@ from app.repositories.neo4j.graph_search import (
# 使用新的仓储层 # 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@@ -324,229 +326,229 @@ def apply_reranker_placeholder(
If config enables reranker, annotate items with a final_score equal to combined_score If config enables reranker, annotate items with a final_score equal to combined_score
and keep ordering. This is a no-op reranker to be replaced later. and keep ordering. This is a no-op reranker to be replaced later.
""" """
try: # try:
rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})) # rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
except Exception as e: # except Exception as e:
logger.debug(f"Failed to load reranker config: {e}") # logger.debug(f"Failed to load reranker config: {e}")
rc = {} # rc = {}
if not rc or not rc.get("enabled", False): # if not rc or not rc.get("enabled", False):
return results # return results
top_k = int(rc.get("top_k", 100)) # top_k = int(rc.get("top_k", 100))
model_name = rc.get("model", "placeholder") # model_name = rc.get("model", "placeholder")
for cat, items in results.items(): # for cat, items in results.items():
head = items[:top_k] # head = items[:top_k]
for it in head: # for it in head:
base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0) # base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
it["final_score"] = base # it["final_score"] = base
it["reranker_model"] = model_name # it["reranker_model"] = model_name
# Keep overall order by final_score if present, otherwise combined/score # # Keep overall order by final_score if present, otherwise combined/score
results[cat] = sorted( # results[cat] = sorted(
items, # items,
key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)), # key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
reverse=True, # reverse=True,
) # )
return results return results
async def apply_llm_reranker( # async def apply_llm_reranker(
results: Dict[str, List[Dict[str, Any]]], # results: Dict[str, List[Dict[str, Any]]],
query_text: str, # query_text: str,
reranker_client: Optional[Any] = None, # reranker_client: Optional[Any] = None,
llm_weight: Optional[float] = None, # llm_weight: Optional[float] = None,
top_k: Optional[int] = None, # top_k: Optional[int] = None,
batch_size: Optional[int] = None, # batch_size: Optional[int] = None,
) -> Dict[str, List[Dict[str, Any]]]: # ) -> Dict[str, List[Dict[str, Any]]]:
""" # """
Apply LLM-based reranking to search results. # Apply LLM-based reranking to search results.
Args: # Args:
results: Search results organized by category # results: Search results organized by category
query_text: Original search query # query_text: Original search query
reranker_client: Optional pre-initialized reranker client # reranker_client: Optional pre-initialized reranker client
llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM) # llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
top_k: Maximum number of items to rerank per category # top_k: Maximum number of items to rerank per category
batch_size: Number of items to process concurrently # batch_size: Number of items to process concurrently
Returns: # Returns:
Reranked results with final_score and reranker_model fields # Reranked results with final_score and reranker_model fields
""" # """
# Load reranker configuration from runtime.json # # Load reranker configuration from runtime.json
try: # # try:
rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}) # # rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
except Exception as e: # # except Exception as e:
logger.debug(f"Failed to load reranker config: {e}") # # logger.debug(f"Failed to load reranker config: {e}")
rc = {} # # rc = {}
# Check if reranking is enabled # # Check if reranking is enabled
enabled = rc.get("enabled", False) # enabled = rc.get("enabled", False)
if not enabled: # if not enabled:
logger.debug("LLM reranking is disabled in configuration") # logger.debug("LLM reranking is disabled in configuration")
return results # return results
# Load configuration parameters with defaults # # Load configuration parameters with defaults
llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5) # llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
top_k = top_k if top_k is not None else rc.get("top_k", 20) # top_k = top_k if top_k is not None else rc.get("top_k", 20)
batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5) # batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
# Initialize reranker client if not provided # # Initialize reranker client if not provided
if reranker_client is None: # if reranker_client is None:
try: # try:
reranker_client = get_reranker_client() # reranker_client = get_reranker_client()
except Exception as e: # except Exception as e:
logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking") # logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
return results # return results
# Get model name for metadata # # Get model name for metadata
model_name = getattr(reranker_client, 'model_name', 'unknown') # model_name = getattr(reranker_client, 'model_name', 'unknown')
# Process each category # # Process each category
reranked_results = {} # reranked_results = {}
for category in ["statements", "chunks", "entities", "summaries"]: # for category in ["statements", "chunks", "entities", "summaries"]:
items = results.get(category, []) # items = results.get(category, [])
if not items: # if not items:
reranked_results[category] = [] # reranked_results[category] = []
continue # continue
# Select top K items by combined_score for reranking # # Select top K items by combined_score for reranking
sorted_items = sorted( # sorted_items = sorted(
items, # items,
key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0), # key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
reverse=True # reverse=True
) # )
top_items = sorted_items[:top_k] # top_items = sorted_items[:top_k]
remaining_items = sorted_items[top_k:] # remaining_items = sorted_items[top_k:]
# Extract text content from each item # # Extract text content from each item
def extract_text(item: Dict[str, Any]) -> str: # def extract_text(item: Dict[str, Any]) -> str:
"""Extract text content from a result item.""" # """Extract text content from a result item."""
# Try different text fields based on category # # Try different text fields based on category
text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or "" # text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
return str(text).strip() # return str(text).strip()
# Batch items for concurrent processing # # Batch items for concurrent processing
batches = [] # batches = []
for i in range(0, len(top_items), batch_size): # for i in range(0, len(top_items), batch_size):
batch = top_items[i:i + batch_size] # batch = top_items[i:i + batch_size]
batches.append(batch) # batches.append(batch)
# Process batches concurrently # # Process batches concurrently
async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Process a batch of items with LLM relevance scoring.""" # """Process a batch of items with LLM relevance scoring."""
scored_batch = [] # scored_batch = []
for item in batch: # for item in batch:
item_text = extract_text(item) # item_text = extract_text(item)
# Skip items with no text # # Skip items with no text
if not item_text: # if not item_text:
item_copy = item.copy() # item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) # combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item_copy["final_score"] = combined_score # item_copy["final_score"] = combined_score
item_copy["llm_relevance_score"] = 0.0 # item_copy["llm_relevance_score"] = 0.0
item_copy["reranker_model"] = model_name # item_copy["reranker_model"] = model_name
scored_batch.append(item_copy) # scored_batch.append(item_copy)
continue # continue
# Create relevance scoring prompt # # Create relevance scoring prompt
prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0. # prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
Query: {query_text} # Query: {query_text}
Result: {item_text} # Result: {item_text}
Respond with only a number between 0.0 and 1.0, where: # Respond with only a number between 0.0 and 1.0, where:
- 0.0 means completely irrelevant # - 0.0 means completely irrelevant
- 1.0 means perfectly relevant # - 1.0 means perfectly relevant
Relevance score:""" # Relevance score:"""
# Send request to LLM # # Send request to LLM
try: # try:
messages = [{"role": "user", "content": prompt}] # messages = [{"role": "user", "content": prompt}]
response = await reranker_client.chat(messages) # response = await reranker_client.chat(messages)
# Parse LLM response to extract relevance score # # Parse LLM response to extract relevance score
response_text = str(response.content if hasattr(response, 'content') else response).strip() # response_text = str(response.content if hasattr(response, 'content') else response).strip()
# Try to extract a float from the response # # Try to extract a float from the response
try: # try:
# Remove any non-numeric characters except decimal point # # Remove any non-numeric characters except decimal point
import re # import re
score_match = re.search(r'(\d+\.?\d*)', response_text) # score_match = re.search(r'(\d+\.?\d*)', response_text)
if score_match: # if score_match:
llm_score = float(score_match.group(1)) # llm_score = float(score_match.group(1))
# Clamp to [0.0, 1.0] # # Clamp to [0.0, 1.0]
llm_score = max(0.0, min(1.0, llm_score)) # llm_score = max(0.0, min(1.0, llm_score))
else: # else:
raise ValueError("No numeric score found in response") # raise ValueError("No numeric score found in response")
except (ValueError, AttributeError) as e: # except (ValueError, AttributeError) as e:
logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}") # logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
llm_score = None # llm_score = None
# Calculate final score # # Calculate final score
item_copy = item.copy() # item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) # combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
if llm_score is not None: # if llm_score is not None:
final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score # final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
item_copy["llm_relevance_score"] = llm_score # item_copy["llm_relevance_score"] = llm_score
else: # else:
# Use combined_score as fallback # # Use combined_score as fallback
final_score = combined_score # final_score = combined_score
item_copy["llm_relevance_score"] = combined_score # item_copy["llm_relevance_score"] = combined_score
item_copy["final_score"] = final_score # item_copy["final_score"] = final_score
item_copy["reranker_model"] = model_name # item_copy["reranker_model"] = model_name
scored_batch.append(item_copy) # scored_batch.append(item_copy)
except Exception as e: # except Exception as e:
logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score") # logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
item_copy = item.copy() # item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) # combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item_copy["final_score"] = combined_score # item_copy["final_score"] = combined_score
item_copy["llm_relevance_score"] = combined_score # item_copy["llm_relevance_score"] = combined_score
item_copy["reranker_model"] = model_name # item_copy["reranker_model"] = model_name
scored_batch.append(item_copy) # scored_batch.append(item_copy)
return scored_batch # return scored_batch
# Process all batches concurrently # # Process all batches concurrently
try: # try:
batch_tasks = [process_batch(batch) for batch in batches] # batch_tasks = [process_batch(batch) for batch in batches]
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True) # batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
# Merge batch results # # Merge batch results
scored_items = [] # scored_items = []
for result in batch_results: # for result in batch_results:
if isinstance(result, Exception): # if isinstance(result, Exception):
logger.warning(f"Batch processing failed: {result}") # logger.warning(f"Batch processing failed: {result}")
continue # continue
scored_items.extend(result) # scored_items.extend(result)
# Add remaining items (not in top K) with their combined_score as final_score # # Add remaining items (not in top K) with their combined_score as final_score
for item in remaining_items: # for item in remaining_items:
item_copy = item.copy() # item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) # combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item_copy["final_score"] = combined_score # item_copy["final_score"] = combined_score
item_copy["reranker_model"] = model_name # item_copy["reranker_model"] = model_name
scored_items.append(item_copy) # scored_items.append(item_copy)
# Sort all items by final_score in descending order # # Sort all items by final_score in descending order
scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True) # scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
reranked_results[category] = scored_items # reranked_results[category] = scored_items
except Exception as e: # except Exception as e:
logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results") # logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
# Return original items with combined_score as final_score # # Return original items with combined_score as final_score
for item in items: # for item in items:
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) # combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item["final_score"] = combined_score # item["final_score"] = combined_score
item["reranker_model"] = model_name # item["reranker_model"] = model_name
reranked_results[category] = items # reranked_results[category] = items
return reranked_results # return reranked_results
async def run_hybrid_search( async def run_hybrid_search(
@@ -556,7 +558,7 @@ async def run_hybrid_search(
limit: int, limit: int,
include: List[str], include: List[str],
output_path: str | None, output_path: str | None,
embedding_id: str, memory_config: "MemoryConfig",
rerank_alpha: float = 0.6, rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False, use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False, use_llm_rerank: bool = False,
@@ -564,11 +566,14 @@ async def run_hybrid_search(
""" """
Run search with specified type: 'keyword', 'embedding', or 'hybrid' Run search with specified type: 'keyword', 'embedding', or 'hybrid'
Args:
memory_config: MemoryConfig object containing embedding_model_id and config_id
""" """
# Start overall timing # Start overall timing
search_start_time = time.time() search_start_time = time.time()
latency_metrics = {} latency_metrics = {}
logger.info(f"using embedding_id:{embedding_id}...") logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
# Clean and normalize the incoming query before use/logging # Clean and normalize the incoming query before use/logging
query_text = extract_plain_query(query_text) query_text = extract_plain_query(query_text)
@@ -621,7 +626,9 @@ async def run_hybrid_search(
# 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig # 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig
config_load_start = time.time() config_load_start = time.time()
embedder_config_dict = get_embedder_config(embedding_id) with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
rb_config = RedBearModelConfig( rb_config = RedBearModelConfig(
model_name=embedder_config_dict["model_name"], model_name=embedder_config_dict["model_name"],
provider=embedder_config_dict["provider"], provider=embedder_config_dict["provider"],
@@ -683,7 +690,7 @@ async def run_hybrid_search(
if use_forgetting_rerank: if use_forgetting_rerank:
# Load forgetting parameters from pipeline config # Load forgetting parameters from pipeline config
try: try:
pc = get_pipeline_config() pc = get_pipeline_config(memory_config)
forgetting_cfg = pc.forgetting_engine forgetting_cfg = pc.forgetting_engine
except Exception as e: except Exception as e:
logger.debug(f"Failed to load forgetting config, using defaults: {e}") logger.debug(f"Failed to load forgetting config, using defaults: {e}")
@@ -711,16 +718,16 @@ async def run_hybrid_search(
# Apply LLM reranking if enabled # Apply LLM reranking if enabled
llm_rerank_applied = False llm_rerank_applied = False
if use_llm_rerank: # if use_llm_rerank:
try: # try:
reranked_results = await apply_llm_reranker( # reranked_results = await apply_llm_reranker(
results=reranked_results, # results=reranked_results,
query_text=query_text, # query_text=query_text,
) # )
llm_rerank_applied = True # llm_rerank_applied = True
logger.info("LLM reranking applied successfully") # logger.info("LLM reranking applied successfully")
except Exception as e: # except Exception as e:
logger.warning(f"LLM reranking failed: {e}, using previous scores") # logger.warning(f"LLM reranking failed: {e}, using previous scores")
results["reranked_results"] = reranked_results results["reranked_results"] = reranked_results
results["combined_summary"] = { results["combined_summary"] = {
@@ -896,90 +903,95 @@ async def search_chunk_by_chunk_id(
return {"chunks": chunks} return {"chunks": chunks}
def main(): # def main():
"""Main entry point for the hybrid graph search CLI. # """Main entry point for the hybrid graph search CLI.
Parses command line arguments and executes search with specified parameters. # Parses command line arguments and executes search with specified parameters.
Supports keyword, embedding, and hybrid search modes. # Supports keyword, embedding, and hybrid search modes.
""" # """
parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options") # parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
parser.add_argument( # parser.add_argument(
"--query", "-q", required=True, help="Free-text query to search" # "--query", "-q", required=True, help="Free-text query to search"
) # )
parser.add_argument( # parser.add_argument(
"--search-type", # "--search-type",
"-t", # "-t",
choices=["keyword", "embedding", "hybrid"], # choices=["keyword", "embedding", "hybrid"],
default="hybrid", # default="hybrid",
help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)" # help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
) # )
parser.add_argument( # parser.add_argument(
"--embedding-name", # "--config-id",
"-m", # "-c",
default="openai/nomic-embed-text:v1.5", # type=int,
help="Embedding config name from config.json (default: openai/nomic-embed-text:v1.5)", # required=True,
) # help="Database configuration ID (required)",
parser.add_argument( # )
"--group-id", # parser.add_argument(
"-g", # "--group-id",
default=None, # "-g",
help="Optional group_id to filter results (default: None)", # default=None,
) # help="Optional group_id to filter results (default: None)",
parser.add_argument( # )
"--limit", # parser.add_argument(
"-k", # "--limit",
type=int, # "-k",
default=5, # type=int,
help="Max number of results per type (default: 5)", # default=5,
) # help="Max number of results per type (default: 5)",
parser.add_argument( # )
"--include", # parser.add_argument(
"-i", # "--include",
nargs="+", # "-i",
default=["statements", "chunks", "entities", "summaries"], # nargs="+",
choices=["statements", "chunks", "entities", "summaries"], # default=["statements", "chunks", "entities", "summaries"],
help="Which targets to search for embedding search (default: statements chunks entities summaries)" # choices=["statements", "chunks", "entities", "summaries"],
) # help="Which targets to search for embedding search (default: statements chunks entities summaries)"
parser.add_argument( # )
"--output", # parser.add_argument(
"-o", # "--output",
default="search_results.json", # "-o",
help="Path to save the search results JSON (default: search_results.json)", # default="search_results.json",
) # help="Path to save the search results JSON (default: search_results.json)",
parser.add_argument( # )
"--rerank-alpha", # parser.add_argument(
"-a", # "--rerank-alpha",
type=float, # "-a",
default=0.6, # type=float,
help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)", # default=0.6,
) # help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
parser.add_argument( # )
"--forgetting-rerank", # parser.add_argument(
action="store_true", # "--forgetting-rerank",
help="Apply forgetting curve during reranking for hybrid search.", # action="store_true",
) # help="Apply forgetting curve during reranking for hybrid search.",
parser.add_argument( # )
"--llm-rerank", # parser.add_argument(
action="store_true", # "--llm-rerank",
help="Apply LLM-based reranking for hybrid search.", # action="store_true",
) # help="Apply LLM-based reranking for hybrid search.",
args = parser.parse_args() # )
# args = parser.parse_args()
asyncio.run( # # Load memory config from database
run_hybrid_search( # from app.services.memory_config_service import MemoryConfigService
query_text=args.query, # memory_config = MemoryConfigService.load_memory_config(args.config_id)
search_type=args.search_type,
group_id=args.group_id, # asyncio.run(
limit=args.limit, # run_hybrid_search(
include=args.include, # query_text=args.query,
output_path=args.output, # search_type=args.search_type,
embedding_id=config_defs.SELECTED_EMBEDDING_ID, # group_id=args.group_id,
rerank_alpha=args.rerank_alpha, # limit=args.limit,
use_forgetting_rerank=args.forgetting_rerank, # include=args.include,
use_llm_rerank=args.llm_rerank, # output_path=args.output,
) # memory_config=memory_config,
) # rerank_alpha=args.rerank_alpha,
# use_forgetting_rerank=args.forgetting_rerank,
# use_llm_rerank=args.llm_rerank,
# )
# )
if __name__ == "__main__": # if __name__ == "__main__":
main() # main()

View File

@@ -1,19 +1,22 @@
""" """
去重功能函数 去重功能函数
""" """
from app.core.memory.models.variate_config import DedupConfig
from typing import List, Dict, Tuple, Any
from app.core.memory.models.graph_models import(
StatementEntityEdge,
EntityEntityEdge,
ExtractedEntityNode
)
import os
from datetime import datetime
import difflib # 提供字符串相似度计算工具
import asyncio import asyncio
import difflib # 提供字符串相似度计算工具
import importlib import importlib
import os
import re import re
from datetime import datetime
from typing import Any, Dict, List, Tuple
from app.core.memory.models.graph_models import (
EntityEntityEdge,
ExtractedEntityNode,
StatementEntityEdge,
)
from app.core.memory.models.variate_config import DedupConfig
# 模块级类型统一工具函数 # 模块级类型统一工具函数
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None: def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
"""统一实体类型基于LLM建议或启发式规则选择最合适的类型。 """统一实体类型基于LLM建议或启发式规则选择最合适的类型。
@@ -705,7 +708,8 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
id_redirect: Dict[str, str], id_redirect: Dict[str, str],
config: DedupConfig | None = None, config: DedupConfig,
llm_client = None,
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]: ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
""" """
基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。 基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。
@@ -717,26 +721,13 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
""" """
llm_records: List[str] = [] llm_records: List[str] = []
try: try:
# 优先使用运行时配置;若未提供配置,使用模型默认值,不再回退到环境变量 if not bool(config.enable_llm_dedup_blockwise):
enable_switch = (
bool(config.enable_llm_dedup_blockwise) if config is not None else DedupConfig().enable_llm_dedup_blockwise
)
if not enable_switch:
return deduped_entities, id_redirect, llm_records
# 从配置读取 LLM 迭代参数;若无配置则使用 DedupConfig 的默认值
_defaults = DedupConfig()
block_size = (config.llm_block_size if config is not None else _defaults.llm_block_size)
block_concurrency = (config.llm_block_concurrency if config is not None else _defaults.llm_block_concurrency)
pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency)
max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds)
# 动态导入 llm 客户端(修正导入路径)
try:
llm_utils_mod = importlib.import_module("app.core.memory.utils.llm.llm_utils")
get_llm_client_fn = llm_utils_mod.get_llm_client
except Exception as e:
llm_records.append(f"[LLM错误] 无法导入 llm_utils 模块: {e}")
return deduped_entities, id_redirect, llm_records return deduped_entities, id_redirect, llm_records
# 从配置读取 LLM 迭代参数
block_size = config.llm_block_size
block_concurrency = config.llm_block_concurrency
pair_concurrency = config.llm_pair_concurrency
max_rounds = config.llm_max_rounds
try: try:
llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm") llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm")
@@ -745,14 +736,9 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
llm_records.append(f"[LLM错误] 无法导入 entity_dedup_llm 模块: {e}") llm_records.append(f"[LLM错误] 无法导入 entity_dedup_llm 模块: {e}")
return deduped_entities, id_redirect, llm_records return deduped_entities, id_redirect, llm_records
# 获取 LLM 客户端 # 验证 LLM 客户端
try: if llm_client is None:
llm_client = get_llm_client_fn() llm_records.append("[LLM错误] LLM 客户端未提供")
if llm_client is None:
llm_records.append("[LLM错误] LLM 客户端初始化失败:返回 None")
return deduped_entities, id_redirect, llm_records
except Exception as e:
llm_records.append(f"[LLM错误] 获取 LLM 客户端失败: {e}")
return deduped_entities, id_redirect, llm_records return deduped_entities, id_redirect, llm_records
llm_redirect, llm_records = await llm_fn( llm_redirect, llm_records = await llm_fn(
@@ -813,7 +799,8 @@ async def LLM_disamb_decision(
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
id_redirect: Dict[str, str], id_redirect: Dict[str, str],
config: DedupConfig | None = None, config: DedupConfig,
llm_client = None,
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]: ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]:
""" """
预消歧阶段对“同名但类型不同”的实体对调用LLM进行消歧 预消歧阶段对“同名但类型不同”的实体对调用LLM进行消歧
@@ -824,22 +811,16 @@ async def LLM_disamb_decision(
disamb_records: List[str] = [] disamb_records: List[str] = []
blocked_pairs: set[tuple[str, str]] = set() blocked_pairs: set[tuple[str, str]] = set()
try: try:
enable_switch = ( if not bool(config.enable_llm_disambiguation):
config.enable_llm_disambiguation
if config is not None
else DedupConfig().enable_llm_disambiguation
)
if not bool(enable_switch):
return deduped_entities, id_redirect, blocked_pairs, disamb_records return deduped_entities, id_redirect, blocked_pairs, disamb_records
from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import (
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative llm_disambiguate_pairs_iterative,
from app.core.memory.utils.config import definitions as config_defs )
# 获取 LLM 客户端并验证 # 验证 LLM 客户端
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
if llm_client is None: if llm_client is None:
disamb_records.append("[DISAMB错误] LLM 客户端初始化失败:返回 None") disamb_records.append("[DISAMB错误] LLM 客户端未提供")
return deduped_entities, id_redirect, blocked_pairs, disamb_records return deduped_entities, id_redirect, blocked_pairs, disamb_records
merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative( merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative(
@@ -895,6 +876,7 @@ async def deduplicate_entities_and_edges(
report_append: bool = False, report_append: bool = False,
report_stage_notes: List[str] | None = None, report_stage_notes: List[str] | None = None,
dedup_config: DedupConfig | None = None, dedup_config: DedupConfig | None = None,
llm_client = None,
) -> Tuple[ ) -> Tuple[
List[ExtractedEntityNode], List[ExtractedEntityNode],
List[StatementEntityEdge], List[StatementEntityEdge],
@@ -911,7 +893,7 @@ async def deduplicate_entities_and_edges(
# 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并 # 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并
deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision( deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision(
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
) )
# 2) 模糊匹配(本地规则) # 2) 模糊匹配(本地规则)
@@ -936,7 +918,7 @@ async def deduplicate_entities_and_edges(
if should_trigger_llm: if should_trigger_llm:
deduped_entities, id_redirect, llm_decision_records = await LLM_decision( deduped_entities, id_redirect, llm_decision_records = await LLM_decision(
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
) )
else: else:
llm_decision_records = [] llm_decision_records = []

View File

@@ -10,15 +10,27 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Dict, Any, Tuple
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Tuple
from app.core.memory.models.graph_models import (
EntityEntityEdge,
ExtractedEntityNode,
StatementEntityEdge,
)
from app.core.memory.models.variate_config import DedupConfig
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( # 导入报告写入以在跳过时追加说明
_write_dedup_fusion_report,
deduplicate_entities_and_edges,
)
from app.repositories.neo4j.graph_search import (
get_dedup_candidates_for_entities, # 导入ge函数用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
)
# 使用新的仓储层 # 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互 from app.repositories.neo4j.neo4j_connector import (
from app.repositories.neo4j.graph_search import get_dedup_candidates_for_entities # 导入ge函数,用于 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。 Neo4jConnector, # 导入 Neo4j 数据库连接器类,用于 Neo4j 数据库进行交互
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge )
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges, _write_dedup_fusion_report # 导入报告写入以在跳过时追加说明
from app.core.memory.models.variate_config import DedupConfig
def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt用于将任意类型的输入值解析为datetime对象处理实体节点中的时间字段 def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt用于将任意类型的输入值解析为datetime对象处理实体节点中的时间字段
@@ -72,6 +84,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系 statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系 entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
dedup_config: DedupConfig | None = None, dedup_config: DedupConfig | None = None,
llm_client = None,
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]: ) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
""" """
第二层去重消歧: 第二层去重消歧:
@@ -137,13 +150,14 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes) union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes)
# 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重) # 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重)
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges = await deduplicate_entities_and_edges( fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges, _ = await deduplicate_entities_and_edges(
union_entities, union_entities,
statement_entity_edges, statement_entity_edges,
entity_entity_edges, entity_entity_edges,
report_stage="第二层去重消歧", report_stage="第二层去重消歧",
report_append=True, report_append=True,
dedup_config=dedup_config, dedup_config=dedup_config,
llm_client=llm_client,
) )
return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges

View File

@@ -1,23 +1,27 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Tuple, Optional from typing import List, Optional, Tuple
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.utils.config.config_utils import get_pipeline_config
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import second_layer_dedup_and_merge_with_neo4j
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.models.graph_models import ( from app.core.memory.models.graph_models import (
DialogueNode,
ChunkNode, ChunkNode,
StatementNode, DialogueNode,
EntityEntityEdge,
ExtractedEntityNode, ExtractedEntityNode,
StatementChunkEdge, StatementChunkEdge,
StatementEntityEdge, StatementEntityEdge,
EntityEntityEdge, StatementNode,
) )
from app.core.memory.models.message_models import DialogData from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
)
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
second_layer_dedup_and_merge_with_neo4j,
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def dedup_layers_and_merge_and_return( async def dedup_layers_and_merge_and_return(
@@ -29,8 +33,9 @@ async def dedup_layers_and_merge_and_return(
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
pipeline_config: Optional[ExtractionPipelineConfig] = None, pipeline_config: ExtractionPipelineConfig,
connector: Optional[Neo4jConnector] = None, connector: Optional[Neo4jConnector] = None,
llm_client = None,
) -> Tuple[ ) -> Tuple[
List[DialogueNode], List[DialogueNode],
List[ChunkNode], List[ChunkNode],
@@ -48,12 +53,9 @@ async def dedup_layers_and_merge_and_return(
返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。 返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。
""" """
# 默认从 runtime.json 加载管线配置,避免回退到环境变量 # pipeline_config is required - caller must provide it
if pipeline_config is None: if pipeline_config is None:
try: raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
pipeline_config = get_pipeline_config()
except Exception:
pipeline_config = None
# 先探测 group_id决定报告写入策略 # 先探测 group_id决定报告写入策略
group_id: Optional[str] = None group_id: Optional[str] = None
@@ -70,6 +72,7 @@ async def dedup_layers_and_merge_and_return(
report_stage="第一层去重消歧", report_stage="第一层去重消歧",
report_append=False, report_append=False,
dedup_config=(pipeline_config.deduplication if pipeline_config else None), dedup_config=(pipeline_config.deduplication if pipeline_config else None),
llm_client=llm_client,
) )
# 初始化第二层融合结果为第一层结果 # 初始化第二层融合结果为第一层结果
@@ -88,6 +91,7 @@ async def dedup_layers_and_merge_and_return(
statement_entity_edges=dedup_statement_entity_edges, statement_entity_edges=dedup_statement_entity_edges,
entity_entity_edges=dedup_entity_entity_edges, entity_entity_edges=dedup_entity_entity_edges,
dedup_config=(pipeline_config.deduplication if pipeline_config else None), dedup_config=(pipeline_config.deduplication if pipeline_config else None),
llm_client=llm_client,
) )
else: else:
print("Skip second-layer dedup: missing connector") print("Skip second-layer dedup: missing connector")

View File

@@ -1253,6 +1253,7 @@ class ExtractionOrchestrator:
report_stage="第一层去重消歧(试运行)", report_stage="第一层去重消歧(试运行)",
report_append=False, report_append=False,
dedup_config=self.config.deduplication, dedup_config=self.config.deduplication,
llm_client=self.llm_client,
) )
# 保存去重消歧的详细记录到实例变量 # 保存去重消歧的详细记录到实例变量
@@ -1284,6 +1285,7 @@ class ExtractionOrchestrator:
dialog_data_list, dialog_data_list,
self.config, self.config,
self.connector, self.connector,
llm_client=self.llm_client,
) )
# 解包返回值 # 解包返回值

View File

@@ -5,11 +5,13 @@
""" """
import asyncio import asyncio
from typing import List, Dict, Any, Tuple from typing import Any, Dict, List, Tuple
from app.core.memory.models.message_models import DialogData
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config from app.core.memory.models.message_models import DialogData
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
class EmbeddingGenerator: class EmbeddingGenerator:
@@ -21,7 +23,9 @@ class EmbeddingGenerator:
Args: Args:
embedding_id: 嵌入模型 ID embedding_id: 嵌入模型 ID
""" """
embedder_config = get_embedder_config(embedding_id) with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config = config_service.get_embedder_config(embedding_id)
self.embedder_client = OpenAIEmbedderClient( self.embedder_client = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(embedder_config), model_config=RedBearModelConfig.model_validate(embedder_config),
) )

View File

@@ -1,21 +1,17 @@
import os
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from uuid import uuid4
from pydantic import Field, field_validator
from app.core.logging_config import get_memory_logger from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.base_response import RobustLLMResponse
from app.core.memory.models.graph_models import MemorySummaryNode
from app.core.memory.models.message_models import DialogData from app.core.memory.models.message_models import DialogData
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
from pydantic import Field
logger = get_memory_logger(__name__) logger = get_memory_logger(__name__)
from app.core.memory.models.graph_models import MemorySummaryNode
from app.core.memory.models.base_response import RobustLLMResponse
from app.core.models.base import RedBearModelConfig
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
from uuid import uuid4
class MemorySummaryResponse(RobustLLMResponse): class MemorySummaryResponse(RobustLLMResponse):
@@ -91,22 +87,17 @@ async def _process_chunk_summary(
return None return None
async def Memory_summary_generation( async def memory_summary_generation(
chunked_dialogs: List[DialogData], chunked_dialogs: List[DialogData],
llm_client, llm_client,
embedding_id, embedder_client: OpenAIEmbedderClient,
) -> List[MemorySummaryNode]: ) -> List[MemorySummaryNode]:
"""Generate memory summaries per chunk, embed them, and return nodes.""" """Generate memory summaries per chunk, embed them, and return nodes."""
embedder_cfg_dict = get_embedder_config(embedding_id)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(embedder_cfg_dict),
)
# Collect all tasks for parallel processing # Collect all tasks for parallel processing
tasks = [] tasks = []
for dialog in chunked_dialogs: for dialog in chunked_dialogs:
for chunk in dialog.chunks: for chunk in dialog.chunks:
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder)) tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder_client))
# Process all chunks in parallel # Process all chunks in parallel
results = await asyncio.gather(*tasks, return_exceptions=False) results = await asyncio.gather(*tasks, return_exceptions=False)

View File

@@ -1,17 +1,21 @@
import os
import asyncio import asyncio
import logging import logging
from typing import List, Optional, Dict, Any import os
from pydantic import BaseModel, Field
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.models.message_models import DialogData, Statement from app.core.memory.models.message_models import DialogData, Statement
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
from app.core.memory.models.variate_config import StatementExtractionConfig from app.core.memory.models.variate_config import StatementExtractionConfig
from app.core.memory.utils.data.ontology import (
LABEL_DEFINITIONS,
RelevenceInfo,
StatementType,
TemporalInfo,
)
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo, RelevenceInfo from pydantic import BaseModel, Field
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -8,21 +8,25 @@
4. 反思结果应用 - 更新记忆库 4. 反思结果应用 - 更新记忆库
""" """
import asyncio
import json import json
import logging import logging
import asyncio
import os import os
import time import time
from typing import List, Dict, Any, Optional
from enum import Enum
import uuid import uuid
from enum import Enum
from pydantic import BaseModel from typing import Any, Dict, List, Optional
from app.core.response_utils import success from app.core.response_utils import success
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all from app.repositories.neo4j.cypher_queries import (
from app.repositories.neo4j.neo4j_update import neo4j_data neo4j_query_all,
neo4j_query_part,
neo4j_statement_all,
neo4j_statement_part,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.neo4j_update import neo4j_data
from pydantic import BaseModel
# 配置日志 # 配置日志
_root_logger = logging.getLogger() _root_logger = logging.getLogger()
@@ -135,14 +139,20 @@ class ReflectionEngine:
self.neo4j_connector = Neo4jConnector() self.neo4j_connector = Neo4jConnector()
if self.llm_client is None: if self.llm_client is None:
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.config import definitions as config_defs
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
with get_db_context() as db:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
elif isinstance(self.llm_client, str): elif isinstance(self.llm_client, str):
# 如果 llm_client 是字符串model_id则用它初始化客户端 # 如果 llm_client 是字符串model_id则用它初始化客户端
from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
model_id = self.llm_client model_id = self.llm_client
self.llm_client = get_llm_client(model_id) with get_db_context() as db:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(model_id)
if self.get_data_func is None: if self.get_data_func is None:
from app.core.memory.utils.config.get_data import get_data from app.core.memory.utils.config.get_data import get_data
@@ -154,11 +164,15 @@ class ReflectionEngine:
self.get_data_statement = get_data_statement self.get_data_statement = get_data_statement
if self.render_evaluate_prompt_func is None: if self.render_evaluate_prompt_func is None:
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt from app.core.memory.utils.prompt.template_render import (
render_evaluate_prompt,
)
self.render_evaluate_prompt_func = render_evaluate_prompt self.render_evaluate_prompt_func = render_evaluate_prompt
if self.render_reflexion_prompt_func is None: if self.render_reflexion_prompt_func is None:
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt from app.core.memory.utils.prompt.template_render import (
render_reflexion_prompt,
)
self.render_reflexion_prompt_func = render_reflexion_prompt self.render_reflexion_prompt_func = render_reflexion_prompt
if self.conflict_schema is None: if self.conflict_schema is None:
@@ -170,7 +184,9 @@ class ReflectionEngine:
self.reflexion_schema = ReflexionResultSchema self.reflexion_schema = ReflexionResultSchema
if self.update_query is None: if self.update_query is None:
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT from app.repositories.neo4j.cypher_queries import (
UPDATE_STATEMENT_INVALID_AT,
)
self.update_query = UPDATE_STATEMENT_INVALID_AT self.update_query = UPDATE_STATEMENT_INVALID_AT
self._lazy_init_done = True self._lazy_init_done = True

View File

@@ -4,10 +4,20 @@
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。 本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
""" """
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult from typing import TYPE_CHECKING
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
from app.core.memory.storage_services.search.search_strategy import (
SearchResult,
SearchStrategy,
)
from app.core.memory.storage_services.search.semantic_search import (
SemanticSearchStrategy,
)
__all__ = [ __all__ = [
"SearchStrategy", "SearchStrategy",
@@ -34,7 +44,7 @@ async def run_hybrid_search(
include: list[str] | None = None, include: list[str] | None = None,
alpha: float = 0.6, alpha: float = 0.6,
use_forgetting_curve: bool = False, use_forgetting_curve: bool = False,
embedding_id: str | None = None, memory_config: "MemoryConfig" = None,
**kwargs **kwargs
) -> dict: ) -> dict:
"""运行混合搜索向后兼容的函数式API """运行混合搜索向后兼容的函数式API
@@ -51,24 +61,26 @@ async def run_hybrid_search(
include: 要包含的搜索类别列表 include: 要包含的搜索类别列表
alpha: BM25分数权重0.0-1.0 alpha: BM25分数权重0.0-1.0
use_forgetting_curve: 是否使用遗忘曲线 use_forgetting_curve: 是否使用遗忘曲线
embedding_id: 嵌入模型ID memory_config: MemoryConfig object containing embedding_model_id
**kwargs: 其他参数 **kwargs: 其他参数
Returns: Returns:
dict: 搜索结果字典格式与旧API兼容 dict: 搜索结果字典格式与旧API兼容
""" """
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 使用提供的embedding_id或默认值 if not memory_config:
emb_id = embedding_id or config_defs.SELECTED_EMBEDDING_ID raise ValueError("memory_config is required for search")
# 初始化客户端 # 初始化客户端
connector = Neo4jConnector() connector = Neo4jConnector()
embedder_config_dict = get_embedder_config(emb_id) with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
embedder_config = RedBearModelConfig(**embedder_config_dict) embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config) embedder_client = OpenAIEmbedderClient(embedder_config)

View File

@@ -5,15 +5,20 @@
使用余弦相似度进行语义匹配。 使用余弦相似度进行语义匹配。
""" """
from typing import List, Dict, Any, Optional from typing import Any, Dict, List, Optional
from app.core.logging_config import get_memory_logger from app.core.logging_config import get_memory_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config from app.core.memory.storage_services.search.search_strategy import (
SearchResult,
SearchStrategy,
)
from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
logger = get_memory_logger(__name__) logger = get_memory_logger(__name__)
@@ -62,7 +67,9 @@ class SemanticSearchStrategy(SearchStrategy):
""" """
try: try:
# 从数据库读取嵌入器配置 # 从数据库读取嵌入器配置
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID) with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config_dict = config_service.get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
rb_config = RedBearModelConfig( rb_config = RedBearModelConfig(
model_name=embedder_config_dict["model_name"], model_name=embedder_config_dict["model_name"],
provider=embedder_config_dict["provider"], provider=embedder_config_dict["provider"],

View File

@@ -9,7 +9,6 @@ from .config_utils import (
get_chunker_config, get_chunker_config,
get_embedder_config, get_embedder_config,
get_model_config, get_model_config,
get_neo4j_config,
get_picture_config, get_picture_config,
get_pipeline_config, get_pipeline_config,
get_pruning_config, get_pruning_config,
@@ -41,7 +40,6 @@ __all__ = [
# config_utils # config_utils
"get_model_config", "get_model_config",
"get_embedder_config", "get_embedder_config",
"get_neo4j_config",
"get_chunker_config", "get_chunker_config",
"get_pipeline_config", "get_pipeline_config",
"get_pruning_config", "get_pruning_config",

View File

@@ -1,90 +1,74 @@
from app.core.memory.models.variate_config import ( """
DedupConfig, Configuration utilities - Backward compatibility layer
ExtractionPipelineConfig,
ForgettingEngineConfig, DEPRECATED: These functions now require a db session parameter.
StatementExtractionConfig, New code should use MemoryConfigService(db) instance directly.
)
from app.core.memory.utils.config.definitions import CONFIG For functions that don't require db (get_pipeline_config, get_pruning_config),
from app.db import get_db they are still re-exported here.
from app.models.models_model import ModelApiKey """
from app.services.model_service import ModelConfigService
from fastapi import status import warnings
from fastapi.exceptions import HTTPException
from sqlalchemy.orm import Session from app.services.memory_config_service import MemoryConfigService
# These functions don't require db - safe to re-export as static methods
get_pipeline_config = MemoryConfigService.get_pipeline_config
get_pruning_config = MemoryConfigService.get_pruning_config
def get_model_config(model_id: str, db: Session | None = None) -> dict: def get_model_config(model_id: str, db=None):
"""DEPRECATED: Use MemoryConfigService(db).get_model_config(model_id) directly."""
if db is None: if db is None:
db_gen = get_db() # get_db 通常是一个生成器 raise ValueError(
db = next(db_gen) # 取到真正的 Session "get_model_config now requires a db session. "
"Use MemoryConfigService(db).get_model_config(model_id) directly."
)
return MemoryConfigService(db).get_model_config(model_id)
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
print(f"模型ID {model_id} 不存在")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
apiConfig: ModelApiKey = config.api_keys[0]
# 从环境变量读取超时和重试配置 def get_embedder_config(embedding_id: str, db=None):
from app.core.config import settings """DEPRECATED: Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."""
model_config = {
"model_name": apiConfig.model_name,
"provider": apiConfig.provider,
"api_key": apiConfig.api_key,
"base_url": apiConfig.api_base,
"model_config_id":apiConfig.model_config_id,
"type": config.type,
# 添加超时和重试配置,避免 LLM 请求超时
"timeout": settings.LLM_TIMEOUT, # 从环境变量读取默认120秒
"max_retries": settings.LLM_MAX_RETRIES, # 从环境变量读取默认2次
}
# 写入model_config.log文件中
with open("logs/model_config.log", "a", encoding="utf-8") as f:
f.write(f"模型ID: {model_id}\n")
f.write(f"模型配置信息:\n{model_config}\n")
f.write("=============================\n\n")
return model_config
def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict:
if db is None: if db is None:
db_gen = get_db() # get_db 通常是一个生成器 raise ValueError(
db = next(db_gen) # 取到真正的 Session "get_embedder_config now requires a db session. "
"Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."
)
return MemoryConfigService(db).get_embedder_config(embedding_id)
config = ModelConfigService.get_model_by_id(db=db, model_id=embedding_id)
if not config:
print(f"嵌入模型ID {embedding_id} 不存在")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
apiConfig: ModelApiKey = config.api_keys[0]
model_config = {
"model_name": apiConfig.model_name,
"provider": apiConfig.provider,
"api_key": apiConfig.api_key,
"base_url": apiConfig.api_base,
"model_config_id":apiConfig.model_config_id,
# Ensure required field for RedBearModelConfig validation
"type": config.type,
# 添加超时和重试配置,避免嵌入服务请求超时
"timeout": 120.0, # 嵌入服务超时时间(秒)
"max_retries": 5, # 最大重试次数
}
# 写入embedder_config.log文件中
with open("logs/embedder_config.log", "a", encoding="utf-8") as f:
f.write(f"嵌入模型ID: {embedding_id}\n")
f.write(f"嵌入模型配置信息:\n{model_config}\n")
f.write("=============================\n\n")
return model_config
def get_neo4j_config() -> dict:
"""Retrieves the Neo4j configuration from the config file."""
return CONFIG.get("neo4j", {})
def get_picture_config(llm_name: str) -> dict: def get_picture_config(llm_name: str) -> dict:
"""Retrieves the configuration for a specific model from the config file.""" """Retrieves the configuration for a specific model from the config file.
.. deprecated::
This function is deprecated and will be removed in a future version.
Use database-backed model configuration instead.
"""
warnings.warn(
"get_picture_config is deprecated and will be removed in a future version. "
"Use database-backed model configuration instead.",
DeprecationWarning,
stacklevel=2
)
for model_config in CONFIG.get("picture_recognition", []): for model_config in CONFIG.get("picture_recognition", []):
if model_config["llm_name"] == llm_name: if model_config["llm_name"] == llm_name:
return model_config return model_config
raise ValueError(f"Model '{llm_name}' not found in config.json") raise ValueError(f"Model '{llm_name}' not found in config.json")
def get_voice_config(llm_name: str) -> dict: def get_voice_config(llm_name: str) -> dict:
"""Retrieves the configuration for a specific model from the config file.""" """Retrieves the configuration for a specific model from the config file.
.. deprecated::
This function is deprecated and will be removed in a future version.
Use database-backed model configuration instead.
"""
warnings.warn(
"get_voice_config is deprecated and will be removed in a future version. "
"Use database-backed model configuration instead.",
DeprecationWarning,
stacklevel=2
)
for model_config in CONFIG.get("voice_recognition", []): for model_config in CONFIG.get("voice_recognition", []):
if model_config["llm_name"] == llm_name: if model_config["llm_name"] == llm_name:
return model_config return model_config
@@ -92,19 +76,8 @@ def get_voice_config(llm_name: str) -> dict:
def get_chunker_config(chunker_strategy: str) -> dict: def get_chunker_config(chunker_strategy: str) -> dict:
"""Retrieves the configuration for a specific chunker strategy. """Retrieves the configuration for a specific chunker strategy."""
Enhancements:
- Supports default configs for `LLMChunker` and `HybridChunker` if not present.
- Falls back to the first available chunker config when the requested one is missing.
"""
# 1) Try to find exact match in config
chunker_list = CONFIG.get("chunker_list", [])
for chunker_config in chunker_list:
if chunker_config.get("chunker_strategy") == chunker_strategy:
return chunker_config
# 2) Provide sane defaults for newer strategies
default_configs = { default_configs = {
"RecursiveChunker": { "RecursiveChunker": {
"chunker_strategy": "RecursiveChunker", "chunker_strategy": "RecursiveChunker",
@@ -112,7 +85,6 @@ def get_chunker_config(chunker_strategy: str) -> dict:
"chunk_size": 512, "chunk_size": 512,
"min_characters_per_chunk": 50 "min_characters_per_chunk": 50
}, },
"LLMChunker": { "LLMChunker": {
"chunker_strategy": "LLMChunker", "chunker_strategy": "LLMChunker",
"embedding_model": "BAAI/bge-m3", "embedding_model": "BAAI/bge-m3",
@@ -137,127 +109,6 @@ def get_chunker_config(chunker_strategy: str) -> dict:
if chunker_strategy in default_configs: if chunker_strategy in default_configs:
return default_configs[chunker_strategy] return default_configs[chunker_strategy]
# 3) Fallback: use first available config but tag with requested strategy
if chunker_list:
fallback = chunker_list[0].copy()
fallback["chunker_strategy"] = chunker_strategy
# Non-fatal notice for visibility in logs if any
print(f"Warning: Using first available chunker config as fallback for '{chunker_strategy}'")
return fallback
# 4) If no configs available at all
raise ValueError( raise ValueError(
f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available" f"Chunker '{chunker_strategy}' not found "
) )
#TODO: Fix this
def get_pipeline_config(
config_id: int,
db: Session | None = None,
) -> ExtractionPipelineConfig:
"""Build ExtractionPipelineConfig from database.
Args:
config_id: Database configuration ID (required). Loads pipeline
settings from the DataConfig table.
db: Optional database session. If not provided, a new session
will be created.
Returns:
ExtractionPipelineConfig with deduplication, statement extraction,
and forgetting engine settings loaded from database.
Raises:
ValueError: If config_id not found in database.
"""
from app.repositories.data_config_repository import DataConfigRepository
# Load from database
if db is None:
db_gen = get_db()
db = next(db_gen)
db_config = DataConfigRepository.get_by_id(db, config_id)
if db_config is None:
raise ValueError(f"Configuration {config_id} not found in database")
# Build DedupConfig from database
dedup_kwargs = {
"enable_llm_dedup_blockwise": bool(db_config.enable_llm_dedup_blockwise) if db_config.enable_llm_dedup_blockwise is not None else False,
"enable_llm_disambiguation": bool(db_config.enable_llm_disambiguation) if db_config.enable_llm_disambiguation is not None else False,
}
# Fuzzy thresholds
if db_config.t_name_strict is not None:
dedup_kwargs["fuzzy_name_threshold_strict"] = db_config.t_name_strict
if db_config.t_type_strict is not None:
dedup_kwargs["fuzzy_type_threshold_strict"] = db_config.t_type_strict
if db_config.t_overall is not None:
dedup_kwargs["fuzzy_overall_threshold"] = db_config.t_overall
dedup_config = DedupConfig(**dedup_kwargs)
# Build StatementExtractionConfig from database
stmt_kwargs = {}
if db_config.statement_granularity is not None:
stmt_kwargs["statement_granularity"] = db_config.statement_granularity
if db_config.include_dialogue_context is not None:
stmt_kwargs["include_dialogue_context"] = bool(db_config.include_dialogue_context)
if db_config.max_context is not None:
stmt_kwargs["max_dialogue_context_chars"] = db_config.max_context
stmt_config = StatementExtractionConfig(**stmt_kwargs)
# Build ForgettingEngineConfig from database
forget_kwargs = {}
if db_config.offset is not None:
forget_kwargs["offset"] = db_config.offset
if db_config.lambda_time is not None:
forget_kwargs["lambda_time"] = db_config.lambda_time
if db_config.lambda_mem is not None:
forget_kwargs["lambda_mem"] = db_config.lambda_mem
forget_config = ForgettingEngineConfig(**forget_kwargs)
return ExtractionPipelineConfig(
statement_extraction=stmt_config,
deduplication=dedup_config,
forgetting_engine=forget_config,
)
def get_pruning_config(
config_id: int,
db: Session | None = None,
) -> dict:
"""Retrieve semantic pruning config from database.
Args:
config_id: Database configuration ID (required).
db: Optional database session.
Returns:
Dict suitable for PruningConfig.model_validate with keys:
- pruning_switch: bool
- pruning_scene: str ("education" | "online_service" | "outbound")
- pruning_threshold: float (0-0.9)
Raises:
ValueError: If config_id not found in database.
"""
from app.repositories.data_config_repository import DataConfigRepository
if db is None:
db_gen = get_db()
db = next(db_gen)
db_config = DataConfigRepository.get_by_id(db, config_id)
if db_config is None:
raise ValueError(f"Configuration {config_id} not found in database")
return {
"pruning_switch": bool(db_config.pruning_enabled) if db_config.pruning_enabled is not None else False,
"pruning_scene": db_config.pruning_scene or "education",
"pruning_threshold": float(db_config.pruning_threshold) if db_config.pruning_threshold is not None else 0.5,
}

View File

@@ -1,269 +1,268 @@
""" # """
配置加载模块 - DEPRECATED # 配置加载模块 - DEPRECATED
⚠️ DEPRECATION NOTICE ⚠️ # ⚠️ DEPRECATION NOTICE ⚠️
This module is deprecated and will be removed in a future version. # This module is deprecated and will be removed in a future version.
Global configuration variables have been eliminated in favor of dependency injection. # Global configuration variables have been eliminated in favor of dependency injection.
Use the new MemoryConfig system instead: # Use the new MemoryConfig system instead:
- app.core.memory_config.config.MemoryConfig for configuration objects # - app.schemas.memory_config_schema.MemoryConfig for configuration objects
- app.services.memory_agent_service.MemoryAgentService.load_memory_config() # - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id)
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
阶段 1: 从 runtime.json 加载配置(路径 A- DEPRECATED # 阶段 1: 从 runtime.json 加载配置(路径 A- DEPRECATED
阶段 2: 从数据库加载配置(路径 B基于 dbrun.json 中的 config_id- DEPRECATED # 阶段 2: 从数据库加载配置(路径 B基于 dbrun.json 中的 config_id- DEPRECATED
阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED # 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
""" # """
import json # import json
import os # import os
import threading # import threading
from datetime import datetime, timedelta # from datetime import datetime, timedelta
from typing import Any, Dict, Optional # from typing import Any, Dict, Optional
#TODO: Fix this # #TODO: Fix this
try: # try:
from dotenv import load_dotenv # from dotenv import load_dotenv
load_dotenv() # load_dotenv()
except Exception: # except Exception:
pass # pass
# Import unified configuration system # # Import unified configuration system
try: # try:
from app.core.config import settings # from app.core.config import settings
USE_UNIFIED_CONFIG = True # USE_UNIFIED_CONFIG = True
except ImportError: # except ImportError:
USE_UNIFIED_CONFIG = False # USE_UNIFIED_CONFIG = False
settings = None # settings = None
# PROJECT_ROOT 应该指向 app/core/memory/ 目录 # # PROJECT_ROOT 应该指向 app/core/memory/ 目录
# __file__ = app/core/memory/utils/config/definitions.py # # __file__ = app/core/memory/utils/config/definitions.py
# os.path.dirname(__file__) = app/core/memory/utils/config # # os.path.dirname(__file__) = app/core/memory/utils/config
# os.path.dirname(...) = app/core/memory/utils # # os.path.dirname(...) = app/core/memory/utils
# os.path.dirname(...) = app/core/memory # # os.path.dirname(...) = app/core/memory
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# DEPRECATED: Global configuration lock removed # # DEPRECATED: Global configuration lock removed
# Use MemoryConfig objects with dependency injection instead # # Use MemoryConfig objects with dependency injection instead
# DEPRECATED: Legacy config.json loading removed # # DEPRECATED: Legacy config.json loading removed
# Use MemoryConfig objects with dependency injection instead # # Use MemoryConfig objects with dependency injection instead
CONFIG = {} # CONFIG = {}
DEFAULT_VALUES = { # DEFAULT_VALUES = {
"llm_name": "openai/qwen-plus", # "llm_name": "openai/qwen-plus",
"embedding_name": "openai/nomic-embed-text:v1.5", # "embedding_name": "openai/nomic-embed-text:v1.5",
"chunker_strategy": "RecursiveChunker", # "chunker_strategy": "RecursiveChunker",
"group_id": "group_123", # "group_id": "group_123",
"user_id": "default_user", # "user_id": "default_user",
"apply_id": "default_apply", # "apply_id": "default_apply",
"llm_agent_name": "openai/qwen-plus", # "llm_agent_name": "openai/qwen-plus",
"llm_verify_name": "openai/qwen-plus", # "llm_verify_name": "openai/qwen-plus",
"llm_image_recognition": "openai/qwen-plus", # "llm_image_recognition": "openai/qwen-plus",
"llm_voice_recognition": "openai/qwen-plus", # "llm_voice_recognition": "openai/qwen-plus",
"prompt_level": "DEBUG", # "prompt_level": "DEBUG",
"reflexion_iteration_period": "3", # "reflexion_iteration_period": "3",
"reflexion_range": "retrieval", # "reflexion_range": "retrieval",
"reflexion_baseline": "TIME", # "reflexion_baseline": "TIME",
} # }
# DEPRECATED: Legacy global variables for backward compatibility only # # DEPRECATED: Legacy global variables for backward compatibility only
# These will be removed in a future version # # These will be removed in a future version
# Use MemoryConfig objects with dependency injection instead # # Use MemoryConfig objects with dependency injection instead
LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true" # # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"]) # # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
# 阶段 1: 从 runtime.json 加载配置(路径 A # # 阶段 1: 从 runtime.json 加载配置(路径 A
def _load_from_runtime_json() -> Dict[str, Any]: # def _load_from_runtime_json() -> Dict[str, Any]:
""" # """
DEPRECATED: Legacy runtime.json loading # DEPRECATED: Legacy runtime.json loading
⚠️ This function is deprecated and will be removed in a future version. # ⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead. # Use MemoryConfig objects with dependency injection instead.
Returns: # Returns:
Dict[str, Any]: Empty configuration (legacy support only) # Dict[str, Any]: Empty configuration (legacy support only)
""" # """
import warnings # import warnings
warnings.warn( # warnings.warn(
"Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.", # "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning, # DeprecationWarning,
stacklevel=2 # stacklevel=2
) # )
return {"selections": {}} # return {"selections": {}}
# 阶段 2: 从数据库加载配置(路径 B- 已整合到统一加载器 # # 阶段 2: 从数据库加载配置(路径 B- 已整合到统一加载器
# 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代 # # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
# 保留此函数仅为向后兼容 # # 保留此函数仅为向后兼容
def _load_from_database() -> Optional[Dict[str, Any]]: # def _load_from_database() -> Optional[Dict[str, Any]]:
""" # """
DEPRECATED: Legacy database configuration loading # DEPRECATED: Legacy database configuration loading
⚠️ This function is deprecated and will be removed in a future version. # ⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead. # Use MemoryConfig objects with dependency injection instead.
Returns: # Returns:
Optional[Dict[str, Any]]: None (deprecated functionality) # Optional[Dict[str, Any]]: None (deprecated functionality)
""" # """
import warnings # import warnings
warnings.warn( # warnings.warn(
"Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.", # "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning, # DeprecationWarning,
stacklevel=2 # stacklevel=2
) # )
return None # return None
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED # # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None: # def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
""" # """
DEPRECATED: 将运行时配置暴露为全局常量供项目使用 # DEPRECATED: 将运行时配置暴露为全局常量供项目使用
⚠️ This function is deprecated and will be removed in a future version. # ⚠️ This function is deprecated and will be removed in a future version.
Global configuration variables have been eliminated in favor of dependency injection. # Global configuration variables have been eliminated in favor of dependency injection.
Use the new MemoryConfig system instead: # Use the new MemoryConfig system instead:
- app.core.memory_config.config.MemoryConfig for configuration objects # - app.core.memory_config.config.MemoryConfig for configuration objects
- Pass configuration objects as parameters instead of using global variables # - Pass configuration objects as parameters instead of using global variables
Args: # Args:
runtime_cfg: 运行时配置字典 # runtime_cfg: 运行时配置字典
""" # """
import warnings # import warnings
warnings.warn( # warnings.warn(
"Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.", # "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning, # DeprecationWarning,
stacklevel=2 # stacklevel=2
) # )
# Keep minimal global state for backward compatibility only # # Keep minimal global state for backward compatibility only
# These will be removed in a future version # # These will be removed in a future version
global RUNTIME_CONFIG, SELECTIONS # global RUNTIME_CONFIG, SELECTIONS
RUNTIME_CONFIG = runtime_cfg # RUNTIME_CONFIG = runtime_cfg
SELECTIONS = RUNTIME_CONFIG.get("selections", {}) # SELECTIONS = RUNTIME_CONFIG.get("selections", {})
# All other global variables have been removed # # All other global variables have been removed
# Use MemoryConfig objects instead # # Use MemoryConfig objects instead
# 初始化:使用统一配置加载器 # # 初始化:使用统一配置加载器
def _initialize_configuration() -> None: # def _initialize_configuration() -> None:
""" # """
DEPRECATED: Legacy configuration initialization # DEPRECATED: Legacy configuration initialization
⚠️ This function is deprecated and will be removed in a future version. # ⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead. # Use MemoryConfig objects with dependency injection instead.
""" # """
import warnings # import warnings
warnings.warn( # warnings.warn(
"Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.", # "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning, # DeprecationWarning,
stacklevel=2 # stacklevel=2
) # )
# Initialize with empty configuration for backward compatibility # # Initialize with empty configuration for backward compatibility
_expose_runtime_constants({"selections": {}}) # _expose_runtime_constants({"selections": {}})
# 模块加载时自动初始化配置 # # 模块加载时自动初始化配置
_initialize_configuration() # _initialize_configuration()
# DEPRECATED: Global variables removed # # DEPRECATED: Global variables removed
# These variables have been eliminated in favor of dependency injection # # These variables have been eliminated in favor of dependency injection
# Use MemoryConfig objects instead of accessing global variables # # Use MemoryConfig objects instead of accessing global variables
# 公共 API动态重新加载配置 # # 公共 API动态重新加载配置
def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool: # def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
""" # """
DEPRECATED: Legacy configuration reloading # DEPRECATED: Legacy configuration reloading
⚠️ This function is deprecated and will be removed in a future version. # ⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead. # Use MemoryConfig objects with dependency injection instead.
For new code, use: # For new code, use:
- app.services.memory_agent_service.MemoryAgentService.load_memory_config() # - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
- app.services.memory_storage_service.MemoryStorageService.load_memory_config() # - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
Args: # Args:
config_id: Configuration ID (deprecated) # config_id: Configuration ID (deprecated)
force_reload: Force reload flag (deprecated) # force_reload: Force reload flag (deprecated)
Returns: # Returns:
bool: Always returns False (deprecated functionality) # bool: Always returns False (deprecated functionality)
""" # """
import logging # import logging
import warnings # import warnings
logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
warnings.warn( # warnings.warn(
"reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.", # "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning, # DeprecationWarning,
stacklevel=2 # stacklevel=2
) # )
logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. " # logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
"Use MemoryConfig objects with dependency injection instead.") # "Use MemoryConfig objects with dependency injection instead.")
return False # return False
def get_current_config_id() -> Optional[str]: # def get_current_config_id() -> Optional[str]:
""" # """
DEPRECATED: Legacy config ID retrieval # DEPRECATED: Legacy config ID retrieval
⚠️ This function is deprecated and will be removed in a future version. # ⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead. # Use MemoryConfig objects with dependency injection instead.
Returns: # Returns:
Optional[str]: None (deprecated functionality) # Optional[str]: None (deprecated functionality)
""" # """
import warnings # import warnings
warnings.warn( # warnings.warn(
"get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.", # "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning, # DeprecationWarning,
stacklevel=2 # stacklevel=2
) # )
return None # return None
def ensure_fresh_config(config_id = None) -> bool: # def ensure_fresh_config(config_id = None) -> bool:
""" # """
DEPRECATED: Legacy configuration freshness check # DEPRECATED: Legacy configuration freshness check
⚠️ This function is deprecated and will be removed in a future version. # ⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead. # Use MemoryConfig objects with dependency injection instead.
For new code, use: # For new code, use:
- app.services.memory_agent_service.MemoryAgentService.load_memory_config() # - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
- app.services.memory_storage_service.MemoryStorageService.load_memory_config() # - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
Args: # Args:
config_id: Configuration ID (deprecated) # config_id: Configuration ID (deprecated)
Returns: # Returns:
bool: Always returns False (deprecated functionality) # bool: Always returns False (deprecated functionality)
""" # """
import logging # import logging
import warnings # import warnings
logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
warnings.warn( # warnings.warn(
"ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.", # "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning, # DeprecationWarning,
stacklevel=2 # stacklevel=2
) # )
logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. " # logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
"Use MemoryConfig objects with dependency injection instead.") # "Use MemoryConfig objects with dependency injection instead.")
return False # return False

View File

@@ -6,8 +6,9 @@ This module provides centralized functions for creating embedder clients.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
if TYPE_CHECKING: if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
@@ -66,7 +67,8 @@ def get_embedder_client(embedding_id: str) -> OpenAIEmbedderClient:
raise ValueError("Embedding ID is required but was not provided") raise ValueError("Embedding ID is required but was not provided")
try: try:
embedder_config_dict = get_embedder_config(embedding_id) with get_db_context() as db:
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_id)
except Exception as e: except Exception as e:
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e

View File

@@ -1,9 +1,9 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session
if TYPE_CHECKING: if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
@@ -13,105 +13,225 @@ async def handle_response(response: type[BaseModel]) -> dict:
return response.model_dump() return response.model_dump()
def get_llm_client_from_config(memory_config: "MemoryConfig") -> OpenAIClient: class MemoryClientFactory:
""" """
Get LLM client from MemoryConfig object. Factory for creating LLM, embedder, and reranker clients.
**PREFERRED METHOD**: Use this function in production code when you have a MemoryConfig object. Initialize once with db session, then call methods without passing db each time.
This ensures proper configuration management and multi-tenant support.
Args:
memory_config: MemoryConfig object containing llm_model_id
Returns:
OpenAIClient: Initialized LLM client
Raises:
ValueError: If LLM model ID is not configured or client initialization fails
Example: Example:
>>> llm_client = get_llm_client_from_config(memory_config) >>> factory = MemoryClientFactory(db)
>>> llm_client = factory.get_llm_client(model_id)
>>> embedder_client = factory.get_embedder_client(embedding_id)
""" """
if not memory_config.llm_model_id:
raise ValueError( def __init__(self, db: Session):
f"Configuration {memory_config.config_id} has no LLM model configured" from app.services.memory_config_service import MemoryConfigService
) self._config_service = MemoryConfigService(db)
return get_llm_client(str(memory_config.llm_model_id))
def get_llm_client(self, llm_id: str) -> OpenAIClient:
"""Get LLM client by model ID."""
if not llm_id:
raise ValueError("LLM ID is required")
try:
model_config = self._config_service.get_model_config(llm_id)
except Exception as e:
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
try:
return OpenAIClient(
RedBearModelConfig(
model_name=model_config.get("model_name"),
provider=model_config.get("provider"),
api_key=model_config.get("api_key"),
base_url=model_config.get("base_url")
),
type_=model_config.get("type")
)
except Exception as e:
model_name = model_config.get('model_name', 'unknown')
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
def get_embedder_client(self, embedding_id: str):
"""Get embedder client by model ID."""
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
if not embedding_id:
raise ValueError("Embedding ID is required")
try:
embedder_config = self._config_service.get_embedder_config(embedding_id)
except Exception as e:
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e
try:
return OpenAIEmbedderClient(
RedBearModelConfig(
model_name=embedder_config.get("model_name"),
provider=embedder_config.get("provider"),
api_key=embedder_config.get("api_key"),
base_url=embedder_config.get("base_url")
)
)
except Exception as e:
model_name = embedder_config.get('model_name', 'unknown')
raise ValueError(f"Failed to initialize embedder client for model '{model_name}': {str(e)}") from e
def get_reranker_client(self, rerank_id: str) -> OpenAIClient:
"""Get reranker client by model ID."""
if not rerank_id:
raise ValueError("Rerank ID is required")
try:
model_config = self._config_service.get_model_config(rerank_id)
except Exception as e:
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
try:
return OpenAIClient(
RedBearModelConfig(
model_name=model_config.get("model_name"),
provider=model_config.get("provider"),
api_key=model_config.get("api_key"),
base_url=model_config.get("base_url")
),
type_=model_config.get("type")
)
except Exception as e:
model_name = model_config.get('model_name', 'unknown')
raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e
def get_llm_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
"""Get LLM client from MemoryConfig object.
Args:
memory_config: Configuration containing llm_model_id
Returns:
OpenAIClient configured for the LLM model
Raises:
ValueError: If memory_config has no LLM model configured
"""
if not memory_config.llm_model_id:
raise ValueError(
f"Configuration {memory_config.config_id} has no LLM model configured"
)
return self.get_llm_client(str(memory_config.llm_model_id))
def get_embedder_client_from_config(self, memory_config: "MemoryConfig"):
"""Get embedder client from MemoryConfig object.
Args:
memory_config: Configuration containing embedding_model_id
Returns:
OpenAIEmbedderClient configured for the embedding model
Raises:
ValueError: If memory_config has no embedding model configured
"""
if not memory_config.embedding_model_id:
raise ValueError(
f"Configuration {memory_config.config_id} has no embedding model configured"
)
return self.get_embedder_client(str(memory_config.embedding_model_id))
def get_reranker_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
"""Get reranker client from MemoryConfig object.
Args:
memory_config: Configuration containing rerank_model_id
Returns:
OpenAIClient configured for the reranker model
Raises:
ValueError: If memory_config has no rerank model configured
"""
if not memory_config.rerank_model_id:
raise ValueError(
f"Configuration {memory_config.config_id} has no rerank model configured"
)
return self.get_reranker_client(str(memory_config.rerank_model_id))
def get_llm_client(llm_id: str): # Legacy functions for backward compatibility
""" def get_llm_client_from_config(memory_config: "MemoryConfig", db: Session) -> OpenAIClient:
Get LLM client by model ID. """Get LLM client from MemoryConfig object.
**LEGACY/TEST METHOD**: Use this function only for: DEPRECATED: Use MemoryClientFactory(db).get_llm_client_from_config(memory_config) instead.
- Test/evaluation code where you have a model ID directly
- Legacy code that hasn't been migrated to MemoryConfig yet
For production code with MemoryConfig, use get_llm_client_from_config() instead. This function is maintained for backward compatibility during migration to the
factory pattern. New code should create a MemoryClientFactory instance and use
its get_llm_client_from_config method directly.
Args: Args:
llm_id: LLM model ID (required) memory_config: Configuration containing llm_model_id
db: Database session
Returns: Returns:
OpenAIClient: Initialized LLM client OpenAIClient configured for the LLM model
Raises: Raises:
ValueError: If llm_id is not provided or client initialization fails ValueError: If memory_config has no LLM model configured
Example:
>>> # For tests/evaluations only
>>> llm_client = get_llm_client("model-uuid-string")
""" """
if not llm_id: return MemoryClientFactory(db).get_llm_client_from_config(memory_config)
raise ValueError("LLM ID is required but was not provided")
try:
model_config = get_model_config(llm_id)
except Exception as e:
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
try:
llm_client = OpenAIClient(RedBearModelConfig(
model_name=model_config.get("model_name"),
provider=model_config.get("provider"),
api_key=model_config.get("api_key"),
base_url=model_config.get("base_url")
),type_=model_config.get("type"))
return llm_client
except Exception as e:
model_name = model_config.get('model_name', 'unknown')
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
def get_reranker_client(rerank_id: str): def get_llm_client(llm_id: str, db: Session) -> OpenAIClient:
""" """Get LLM client by model ID.
Get an LLM client configured for reranking.
DEPRECATED: Use MemoryClientFactory(db).get_llm_client(llm_id) instead.
This function is maintained for backward compatibility during migration to the
factory pattern. New code should create a MemoryClientFactory instance and use
its get_llm_client method directly.
Args: Args:
rerank_id: Reranker model ID (required) llm_id: LLM model ID
db: Database session
Returns: Returns:
OpenAIClient: Initialized client for the reranker model OpenAIClient configured for the LLM model
Raises:
ValueError: If rerank_id is not provided or client initialization fails
""" """
if not rerank_id: return MemoryClientFactory(db).get_llm_client(llm_id)
raise ValueError("Rerank ID is required but was not provided")
try:
model_config = get_model_config(rerank_id)
except Exception as e:
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
try: def get_embedder_client(embedding_id: str, db: Session):
reranker_client = OpenAIClient(RedBearModelConfig( """Get embedder client by model ID.
model_name=model_config.get("model_name"),
provider=model_config.get("provider"), DEPRECATED: Use MemoryClientFactory(db).get_embedder_client(embedding_id) instead.
api_key=model_config.get("api_key"),
base_url=model_config.get("base_url") This function is maintained for backward compatibility during migration to the
),type_=model_config.get("type")) factory pattern. New code should create a MemoryClientFactory instance and use
return reranker_client its get_embedder_client method directly.
except Exception as e:
model_name = model_config.get('model_name', 'unknown') Args:
raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e embedding_id: Embedding model ID
db: Database session
Returns:
OpenAIEmbedderClient configured for the embedding model
"""
return MemoryClientFactory(db).get_embedder_client(embedding_id)
def get_reranker_client(rerank_id: str, db: Session) -> OpenAIClient:
"""Get reranker client by model ID.
DEPRECATED: Use MemoryClientFactory(db).get_reranker_client(rerank_id) instead.
This function is maintained for backward compatibility during migration to the
factory pattern. New code should create a MemoryClientFactory instance and use
its get_reranker_client method directly.
Args:
rerank_id: Reranker model ID
db: Database session
Returns:
OpenAIClient configured for the reranker model
"""
return MemoryClientFactory(db).get_reranker_client(rerank_id)

View File

@@ -6,11 +6,12 @@
""" """
import logging import logging
from typing import List, Any
import time import time
from typing import Any, List
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
from app.core.memory.utils.llm.llm_utils import get_llm_client from app.db import get_db_context
from app.schemas.memory_storage_schema import ConflictResultSchema from app.schemas.memory_storage_schema import ConflictResultSchema
from pydantic import BaseModel from pydantic import BaseModel
@@ -25,7 +26,9 @@ async def conflict(evaluate_data: List[Any]) -> List[Any]:
冲突记忆列表JSON 数组)。 冲突记忆列表JSON 数组)。
""" """
from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.config import definitions as config_defs
client = get_llm_client(config_defs.SELECTED_LLM_ID) with get_db_context() as db:
factory = MemoryClientFactory(db)
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema) rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
messages = [{"role": "user", "content": rendered_prompt}] messages = [{"role": "user", "content": rendered_prompt}]
print(f"提示词长度: {len(rendered_prompt)}") print(f"提示词长度: {len(rendered_prompt)}")

View File

@@ -6,11 +6,12 @@
""" """
import logging import logging
from typing import List, Any
import time import time
from typing import Any, List
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
from app.core.memory.utils.llm.llm_utils import get_llm_client from app.db import get_db_context
from app.schemas.memory_storage_schema import ReflexionResultSchema from app.schemas.memory_storage_schema import ReflexionResultSchema
from pydantic import BaseModel from pydantic import BaseModel
@@ -25,7 +26,9 @@ async def reflexion(ref_data: List[Any]) -> List[Any]:
反思结果列表JSON 数组)。 反思结果列表JSON 数组)。
""" """
from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.config import definitions as config_defs
client = get_llm_client(config_defs.SELECTED_LLM_ID) with get_db_context() as db:
factory = MemoryClientFactory(db)
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema) rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema)
messages = [{"role": "user", "content": rendered_prompt}] messages = [{"role": "user", "content": rendered_prompt}]
print(f"提示词长度: {len(rendered_prompt)}") print(f"提示词长度: {len(rendered_prompt)}")

View File

@@ -5,16 +5,24 @@ This module provides functionality to analyze chunk content and generate insight
""" """
import asyncio import asyncio
from typing import List, Dict, Any
from collections import Counter from collections import Counter
from typing import Any, Dict, List
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.logging_config import get_business_logger
business_logger = get_business_logger() business_logger = get_business_logger()
def _get_llm_client():
"""Get LLM client using db context."""
with get_db_context() as db:
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
class ChunkInsight(BaseModel): class ChunkInsight(BaseModel):
"""Pydantic model for chunk insight.""" """Pydantic model for chunk insight."""
insight: str = Field(..., description="对chunk内容的深度洞察分析") insight: str = Field(..., description="对chunk内容的深度洞察分析")
@@ -40,7 +48,7 @@ async def classify_chunk_domain(chunk: str) -> str:
Domain name Domain name
""" """
try: try:
llm_client = get_llm_client() llm_client = _get_llm_client()
prompt = f"""请将以下文本内容归类到最合适的领域中。 prompt = f"""请将以下文本内容归类到最合适的领域中。
@@ -177,7 +185,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str
] ]
# 调用LLM生成洞察 # 调用LLM生成洞察
llm_client = get_llm_client() llm_client = _get_llm_client()
response = await llm_client.chat(messages=messages) response = await llm_client.chat(messages=messages)
insight = response.content.strip() insight = response.content.strip()

View File

@@ -5,15 +5,23 @@ This module provides functionality to summarize chunk content using LLM.
""" """
import asyncio import asyncio
from typing import List, Dict, Any from typing import Any, Dict, List
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.logging_config import get_business_logger
business_logger = get_business_logger() business_logger = get_business_logger()
def _get_llm_client():
"""Get LLM client using db context."""
with get_db_context() as db:
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
class ChunkSummary(BaseModel): class ChunkSummary(BaseModel):
"""Pydantic model for chunk summary.""" """Pydantic model for chunk summary."""
summary: str = Field(..., description="简洁的chunk内容摘要") summary: str = Field(..., description="简洁的chunk内容摘要")
@@ -59,7 +67,7 @@ async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str
] ]
# 调用LLM生成摘要 # 调用LLM生成摘要
llm_client = get_llm_client() llm_client = _get_llm_client()
response = await llm_client.chat(messages=messages) response = await llm_client.chat(messages=messages)
summary = response.content.strip() summary = response.content.strip()

View File

@@ -7,14 +7,22 @@ This module provides functionality to extract meaningful tags from chunk content
import asyncio import asyncio
from collections import Counter from collections import Counter
from typing import List, Tuple from typing import List, Tuple
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.logging_config import get_business_logger
business_logger = get_business_logger() business_logger = get_business_logger()
def _get_llm_client():
"""Get LLM client using db context."""
with get_db_context() as db:
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
class ExtractedTags(BaseModel): class ExtractedTags(BaseModel):
"""Pydantic model for extracted tags.""" """Pydantic model for extracted tags."""
tags: List[str] = Field(..., description="从文本中提取的关键标签列表") tags: List[str] = Field(..., description="从文本中提取的关键标签列表")
@@ -56,7 +64,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
"标签应该是名词或名词短语,能够准确概括文本的核心内容。" "标签应该是名词或名词短语,能够准确概括文本的核心内容。"
) )
llm_client = get_llm_client() llm_client = _get_llm_client()
# 为每个chunk单独提取标签然后统计频率 # 为每个chunk单独提取标签然后统计频率
all_tags = [] all_tags = []
@@ -151,7 +159,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
] ]
# 调用LLM提取人物形象 # 调用LLM提取人物形象
llm_client = get_llm_client() llm_client = _get_llm_client()
structured_response = await llm_client.response_structured( structured_response = await llm_client.response_structured(
messages=messages, messages=messages,
response_model=ExtractedPersona response_model=ExtractedPersona

View File

@@ -391,6 +391,29 @@ class MemoryConfig:
embedding_params: Dict[str, Any] = field(default_factory=dict) embedding_params: Dict[str, Any] = field(default_factory=dict)
config_version: str = "2.0" config_version: str = "2.0"
# Pipeline config: Deduplication
enable_llm_dedup_blockwise: bool = False
enable_llm_disambiguation: bool = False
deep_retrieval: bool = True
t_type_strict: float = 0.8
t_name_strict: float = 0.8
t_overall: float = 0.8
# Pipeline config: Statement extraction
statement_granularity: int = 2
include_dialogue_context: bool = False
max_dialogue_context_chars: int = 1000
# Pipeline config: Forgetting engine
lambda_time: float = 0.5
lambda_mem: float = 0.5
offset: float = 0.0
# Pipeline config: Pruning
pruning_enabled: bool = False
pruning_scene: Optional[str] = "education"
pruning_threshold: float = 0.5
def __post_init__(self): def __post_init__(self):
"""Validate configuration after initialization.""" """Validate configuration after initialization."""
if not self.config_name or not self.config_name.strip(): if not self.config_name or not self.config_name.strip():

View File

@@ -3,26 +3,26 @@
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。 提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
""" """
import time
import uuid
import json
import asyncio import asyncio
import datetime import datetime
from typing import Dict, Any, Optional, List, AsyncGenerator import json
from langchain.tools import tool import time
from pydantic import BaseModel, Field import uuid
from sqlalchemy.orm import Session from typing import Any, AsyncGenerator, Dict, List, Optional
from sqlalchemy import select
from app.models import AgentConfig, ModelConfig, ModelApiKey
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_parameter_merger import ModelParameterMerger
from app.core.rag.nlp.search import knowledge_retrieval from langchain.tools import tool
from app.services.langchain_tool_server import Search from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger() logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel): class KnowledgeRetrievalInput(BaseModel):
@@ -83,17 +83,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
""" """
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}") logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try: try:
memory_content = asyncio.run( from app.db import get_db
MemoryAgentService().read_memory( db = next(get_db())
group_id=end_user_id, try:
message=question, memory_content = asyncio.run(
history=[], MemoryAgentService().read_memory(
search_switch="1", group_id=end_user_id,
config_id=config_id, message=question,
storage_type=storage_type, history=[],
user_rag_memory_id=user_rag_memory_id search_switch="1",
config_id=config_id,
db=db,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
) )
) finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}') logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id}) logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
@@ -713,9 +719,9 @@ class DraftRunService:
Raises: Raises:
BusinessException: 当指定的会话不存在时 BusinessException: 当指定的会话不存在时
""" """
from app.services.conversation_service import ConversationService
from app.schemas.conversation_schema import ConversationCreate
from app.models import Conversation as ConversationModel from app.models import Conversation as ConversationModel
from app.schemas.conversation_schema import ConversationCreate
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db) conversation_service = ConversationService(self.db)

View File

@@ -7,14 +7,15 @@ Classes:
EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能 EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能
""" """
from typing import Dict, Any, Optional, List
import statistics
import json import json
from pydantic import BaseModel, Field import statistics
from typing import Any, Dict, List, Optional
from app.core.logging_config import get_business_logger
from app.repositories.neo4j.emotion_repository import EmotionRepository from app.repositories.neo4j.emotion_repository import EmotionRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.logging_config import get_business_logger from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = get_business_logger() logger = get_business_logger()
@@ -454,7 +455,7 @@ class EmotionAnalyticsService:
async def generate_emotion_suggestions( async def generate_emotion_suggestions(
self, self,
end_user_id: str, end_user_id: str,
config_id: Optional[int] = None db: Session,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""生成个性化情绪建议 """生成个性化情绪建议
@@ -462,7 +463,7 @@ class EmotionAnalyticsService:
Args: Args:
end_user_id: 宿主ID用户组ID end_user_id: 宿主ID用户组ID
config_id: 配置ID可选用于从数据库加载LLM配置 db: 数据库会话
Returns: Returns:
Dict: 包含个性化建议的响应: Dict: 包含个性化建议的响应:
@@ -470,14 +471,32 @@ class EmotionAnalyticsService:
- suggestions: 建议列表3-5条 - suggestions: 建议列表3-5条
""" """
try: try:
logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}") logger.info(f"生成个性化情绪建议: user={end_user_id}")
# 1. 如果提供了 config_id从数据库加载配置 # 1. 从 end_user_id 获取关联的 memory_config_id
if config_id is not None: llm_client = None
from app.core.memory.utils.config.definitions import reload_configuration_from_database try:
config_loaded = reload_configuration_from_database(config_id) from app.services.memory_agent_service import (
if not config_loaded: get_end_user_connected_config,
logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置") )
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is not None:
from app.services.memory_config_service import (
MemoryConfigService,
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=int(config_id),
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
)
from app.core.memory.client_factory import MemoryClientFactory
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
except Exception as e:
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
# 2. 获取情绪健康数据 # 2. 获取情绪健康数据
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d") health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
@@ -498,8 +517,9 @@ class EmotionAnalyticsService:
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile) prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
# 7. 调用LLM生成建议使用配置中的LLM # 7. 调用LLM生成建议使用配置中的LLM
from app.core.memory.utils.llm.llm_utils import get_llm_client if llm_client is None:
llm_client = get_llm_client() # 无法获取配置时,抛出错误而不是使用默认配置
raise ValueError("无法获取LLM配置请确保end_user关联了有效的memory_config")
# 将 prompt 转换为 messages 格式 # 将 prompt 转换为 messages 格式
messages = [ messages = [
@@ -598,7 +618,9 @@ class EmotionAnalyticsService:
Returns: Returns:
str: LLM prompt str: LLM prompt
""" """
from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_suggestions_prompt,
)
prompt = await render_emotion_suggestions_prompt( prompt = await render_emotion_suggestions_prompt(
health_data=health_data, health_data=health_data,

View File

@@ -9,10 +9,12 @@ Classes:
import logging import logging
from typing import Optional from typing import Optional
from app.core.memory.models.emotion_models import EmotionExtraction
from app.models.data_config_model import DataConfig
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.llm_tools.llm_client import LLMClientException from app.core.memory.llm_tools.llm_client import LLMClientException
from app.core.memory.models.emotion_models import EmotionExtraction
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.data_config_model import DataConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -50,7 +52,9 @@ class EmotionExtractionService:
""" """
if self.llm_client is None or model_id: if self.llm_client is None or model_id:
effective_model_id = model_id or self.llm_id effective_model_id = model_id or self.llm_id
self.llm_client = get_llm_client(effective_model_id) with get_db_context() as db:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(effective_model_id)
return self.llm_client return self.llm_client
async def extract_emotion( async def extract_emotion(
@@ -142,7 +146,9 @@ class EmotionExtractionService:
Returns: Returns:
Formatted prompt string for LLM Formatted prompt string for LLM
""" """
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_extraction_prompt,
)
prompt = await render_emotion_extraction_prompt( prompt = await render_emotion_extraction_prompt(
statement=statement, statement=statement,

View File

@@ -21,8 +21,8 @@ from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
@@ -45,8 +45,7 @@ config_logger = get_config_logger()
# Initialize Neo4j connector for analytics functions # Initialize Neo4j connector for analytics functions
_neo4j_connector = Neo4jConnector() _neo4j_connector = Neo4jConnector()
db_gen = get_db()
db = next(db_gen)
class MemoryAgentService: class MemoryAgentService:
"""Service for memory agent operations""" """Service for memory agent operations"""
@@ -55,27 +54,6 @@ class MemoryAgentService:
self.user_locks: Dict[str, Lock] = {} self.user_locks: Dict[str, Lock] = {}
self.locks_lock = Lock() self.locks_lock = Lock()
def load_memory_config(self, config_id: int) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
This method delegates to the centralized MemoryConfigService to avoid
code duplication with other services.
Args:
config_id: Configuration ID from database
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If validation fails
"""
return MemoryConfigService.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
def writer_messages_deal(self,messages,start_time,group_id,config_id,message): def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
countext = re.findall(r'"status": "(.*?)",', messages)[0] countext = re.findall(r'"status": "(.*?)",', messages)[0]
@@ -277,14 +255,17 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources") logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic # LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str: async def write_memory(self, group_id: str, message: str, config_id: str, db: Session, storage_type: str, user_rag_memory_id: str) -> str:
""" """
Process write operation with config_id Process write operation with config_id
Args: Args:
group_id: Group identifier group_id: Group identifier (also used as end_user_id)
message: Message to write message: Message to write
config_id: Configuration ID from database config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns: Returns:
Write operation result status Write operation result status
@@ -292,14 +273,24 @@ class MemoryAgentService:
Raises: Raises:
ValueError: If config loading fails or write operation fails ValueError: If config loading fails or write operation fails
""" """
if config_id==None: # Resolve config_id if None using end_user's connected config
config_id = os.getenv("config_id") if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
import time import time
start_time = time.time() start_time = time.time()
# Load configuration from database only # Load configuration from database only
try: try:
memory_config = self.load_memory_config(config_id) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}") logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e: except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
@@ -366,6 +357,7 @@ class MemoryAgentService:
history: List[Dict], history: List[Dict],
search_switch: str, search_switch: str,
config_id: str, config_id: str,
db: Session,
storage_type: str, storage_type: str,
user_rag_memory_id: str user_rag_memory_id: str
) -> Dict: ) -> Dict:
@@ -378,11 +370,14 @@ class MemoryAgentService:
- "2": Direct answer based on context - "2": Direct answer based on context
Args: Args:
group_id: Group identifier group_id: Group identifier (also used as end_user_id)
message: User message message: User message
history: Conversation history history: Conversation history
search_switch: Search mode switch search_switch: Search mode switch
config_id: Configuration ID from database config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns: Returns:
Dict with 'answer' and 'intermediate_outputs' keys Dict with 'answer' and 'intermediate_outputs' keys
@@ -394,8 +389,13 @@ class MemoryAgentService:
import time import time
start_time = time.time() start_time = time.time()
if config_id==None: # Resolve config_id if None using end_user's connected config
config_id = os.getenv("config_id") if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}") logger.info(f"Read operation for group {group_id} with config_id {config_id}")
@@ -411,7 +411,11 @@ class MemoryAgentService:
with group_lock: with group_lock:
# Step 1: Load configuration from database only # Step 1: Load configuration from database only
try: try:
memory_config = self.load_memory_config(config_id) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}") logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e: except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
@@ -696,7 +700,11 @@ class MemoryAgentService:
logger.info("Classifying message type") logger.info("Classifying message type")
# Load configuration to get LLM model ID # Load configuration to get LLM model ID
memory_config = self.load_memory_config(config_id) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
status = await status_typle(message, memory_config.llm_model_id) status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}") logger.debug(f"Message type: {status}")
@@ -865,7 +873,8 @@ class MemoryAgentService:
self, self,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None, current_user_id: Optional[str] = None,
llm_id: Optional[str] = None llm_id: Optional[str] = None,
db: Session = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取用户详情,包含: 获取用户详情,包含:
@@ -877,6 +886,7 @@ class MemoryAgentService:
- end_user_id: 用户ID可选 - end_user_id: 用户ID可选
- current_user_id: 当前登录用户的ID保留参数 - current_user_id: 当前登录用户的ID保留参数
- llm_id: LLM模型ID用于生成标签可选如果不提供则跳过标签生成 - llm_id: LLM模型ID用于生成标签可选如果不提供则跳过标签生成
- db: 数据库会话(可选)
返回格式: 返回格式:
{ {
@@ -893,7 +903,7 @@ class MemoryAgentService:
# 1. 根据 end_user_id 获取 end_user_name # 1. 根据 end_user_id 获取 end_user_name
try: try:
if end_user_id: if end_user_id and db:
from app.repositories import end_user_repository from app.repositories import end_user_repository
from app.schemas.end_user_schema import EndUser as EndUserSchema from app.schemas.end_user_schema import EndUser as EndUserSchema
@@ -948,7 +958,9 @@ class MemoryAgentService:
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities") logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
# 使用LLM提取标签 # 使用LLM提取标签
llm_client = get_llm_client(llm_id) with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_id)
# 定义标签提取的结构 # 定义标签提取的结构
class UserTags(BaseModel): class UserTags(BaseModel):
@@ -1111,3 +1123,68 @@ class MemoryAgentService:
# "error_code": "DOC_PARSE_ERROR", # "error_code": "DOC_PARSE_ERROR",
# "data": {"error": str(e)} # "data": {"error": str(e)}
# } # }
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
"""
获取终端用户关联的记忆配置
通过以下流程获取配置:
1. 根据 end_user_id 获取用户的 app_id
2. 获取该应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id
Args:
end_user_id: 终端用户ID
db: 数据库会话
Returns:
包含 memory_config_id 和相关信息的字典
Raises:
ValueError: 当终端用户不存在或应用未发布时
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from sqlalchemy import select
logger.info(f"Getting connected config for end_user: {end_user_id}")
# 1. 获取 end_user 及其 app_id
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
logger.warning(f"End user not found: {end_user_id}")
raise ValueError(f"终端用户不存在: {end_user_id}")
app_id = end_user.app_id
logger.debug(f"Found end_user app_id: {app_id}")
# 2. 获取该应用的最新发布版本
stmt = (
select(AppRelease)
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
.order_by(AppRelease.version.desc())
)
latest_release = db.scalars(stmt).first()
if not latest_release:
logger.warning(f"No active release found for app: {app_id}")
raise ValueError(f"应用未发布: {app_id}")
logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")
# 3. 从 config 中提取 memory_config_id
config = latest_release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
result = {
"end_user_id": str(end_user_id),
"app_id": str(app_id),
"release_id": str(latest_release.id),
"release_version": latest_release.version,
"memory_config_id": memory_config_id
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
return result

View File

@@ -3,7 +3,6 @@ Memory Configuration Service
Centralized configuration loading and management for memory services. Centralized configuration loading and management for memory services.
This service eliminates code duplication between MemoryAgentService and MemoryStorageService. This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
Database session management is handled internally.
""" """
import time import time
@@ -57,7 +56,7 @@ def _validate_config_id(config_id):
invalid_value=config_id, invalid_value=config_id,
) )
return parsed_id return parsed_id
except ValueError as e: except ValueError:
raise InvalidConfigError( raise InvalidConfigError(
f"Invalid configuration ID format: '{config_id}'", f"Invalid configuration ID format: '{config_id}'",
field_name="config_id", field_name="config_id",
@@ -77,19 +76,29 @@ class MemoryConfigService:
This class provides a single implementation of configuration loading logic This class provides a single implementation of configuration loading logic
that can be shared across multiple services, eliminating code duplication. that can be shared across multiple services, eliminating code duplication.
Database session management is handled internally.
Usage:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
model_config = config_service.get_model_config(model_id)
""" """
@staticmethod def __init__(self, db: Session):
"""Initialize the service with a database session.
Args:
db: SQLAlchemy database session
"""
self.db = db
def load_memory_config( def load_memory_config(
self,
config_id: int, config_id: int,
service_name: str = "MemoryConfigService", service_name: str = "MemoryConfigService",
) -> MemoryConfig: ) -> MemoryConfig:
""" """
Load memory configuration from database by config_id. Load memory configuration from database by config_id.
This method manages its own database session internally.
Args: Args:
config_id: Configuration ID from database config_id: Configuration ID from database
service_name: Name of the calling service (for logging purposes) service_name: Name of the calling service (for logging purposes)
@@ -100,27 +109,6 @@ class MemoryConfigService:
Raises: Raises:
ConfigurationError: If validation fails ConfigurationError: If validation fails
""" """
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
try:
return MemoryConfigService._load_memory_config_with_db(
config_id=config_id,
db=db,
service_name=service_name,
)
finally:
db.close()
@staticmethod
def _load_memory_config_with_db(
config_id: int,
db: Session,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""Internal method that loads memory configuration with an existing db session."""
start_time = time.time() start_time = time.time()
config_logger.info( config_logger.info(
@@ -137,7 +125,7 @@ class MemoryConfigService:
try: try:
validated_config_id = _validate_config_id(config_id) validated_config_id = _validate_config_id(config_id)
result = DataConfigRepository.get_config_with_workspace(db, validated_config_id) result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
if not result: if not result:
elapsed_ms = (time.time() - start_time) * 1000 elapsed_ms = (time.time() - start_time) * 1000
config_logger.error( config_logger.error(
@@ -160,7 +148,7 @@ class MemoryConfigService:
embedding_uuid = validate_embedding_model( embedding_uuid = validate_embedding_model(
validated_config_id, validated_config_id,
memory_config.embedding_id, memory_config.embedding_id,
db, self.db,
workspace.tenant_id, workspace.tenant_id,
workspace.id, workspace.id,
) )
@@ -169,7 +157,7 @@ class MemoryConfigService:
llm_uuid, llm_name = validate_and_resolve_model_id( llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id, memory_config.llm_id,
"llm", "llm",
db, self.db,
workspace.tenant_id, workspace.tenant_id,
required=True, required=True,
config_id=validated_config_id, config_id=validated_config_id,
@@ -183,7 +171,7 @@ class MemoryConfigService:
rerank_uuid, rerank_name = validate_and_resolve_model_id( rerank_uuid, rerank_name = validate_and_resolve_model_id(
memory_config.rerank_id, memory_config.rerank_id,
"rerank", "rerank",
db, self.db,
workspace.tenant_id, workspace.tenant_id,
required=False, required=False,
config_id=validated_config_id, config_id=validated_config_id,
@@ -194,7 +182,7 @@ class MemoryConfigService:
embedding_name, _ = validate_model_exists_and_active( embedding_name, _ = validate_model_exists_and_active(
embedding_uuid, embedding_uuid,
"embedding", "embedding",
db, self.db,
workspace.tenant_id, workspace.tenant_id,
config_id=validated_config_id, config_id=validated_config_id,
workspace_id=workspace.id, workspace_id=workspace.id,
@@ -220,6 +208,25 @@ class MemoryConfigService:
reflexion_range=memory_config.reflexion_range or "retrieval", reflexion_range=memory_config.reflexion_range or "retrieval",
reflexion_baseline=memory_config.baseline or "time", reflexion_baseline=memory_config.baseline or "time",
loaded_at=datetime.now(), loaded_at=datetime.now(),
# Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
) )
elapsed_ms = (time.time() - start_time) * 1000 elapsed_ms = (time.time() - start_time) * 1000
@@ -262,3 +269,131 @@ class MemoryConfigService:
raise raise
else: else:
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}") raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
def get_model_config(self, model_id: str) -> dict:
"""Get LLM model configuration by ID.
Args:
model_id: Model ID to look up
Returns:
Dict with model configuration including api_key, base_url, etc.
"""
from app.core.config import settings
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
from fastapi import status
from fastapi.exceptions import HTTPException
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
if not config:
logger.warning(f"Model ID {model_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
"api_key": api_config.api_key,
"base_url": api_config.api_base,
"model_config_id": api_config.model_config_id,
"type": config.type,
"timeout": settings.LLM_TIMEOUT,
"max_retries": settings.LLM_MAX_RETRIES,
}
def get_embedder_config(self, embedding_id: str) -> dict:
"""Get embedding model configuration by ID.
Args:
embedding_id: Embedding model ID to look up
Returns:
Dict with embedder configuration including api_key, base_url, etc.
"""
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
from fastapi import status
from fastapi.exceptions import HTTPException
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
if not config:
logger.warning(f"Embedding model ID {embedding_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
"api_key": api_config.api_key,
"base_url": api_config.api_base,
"model_config_id": api_config.model_config_id,
"type": config.type,
"timeout": 120.0,
"max_retries": 5,
}
@staticmethod
def get_pipeline_config(memory_config: MemoryConfig):
"""Build ExtractionPipelineConfig from MemoryConfig.
Args:
memory_config: MemoryConfig object containing all pipeline settings.
Returns:
ExtractionPipelineConfig with deduplication, statement extraction,
and forgetting engine settings.
"""
from app.core.memory.models.variate_config import (
DedupConfig,
ExtractionPipelineConfig,
ForgettingEngineConfig,
StatementExtractionConfig,
)
dedup_config = DedupConfig(
enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
fuzzy_name_threshold_strict=memory_config.t_name_strict,
fuzzy_type_threshold_strict=memory_config.t_type_strict,
fuzzy_overall_threshold=memory_config.t_overall,
)
stmt_config = StatementExtractionConfig(
statement_granularity=memory_config.statement_granularity,
include_dialogue_context=memory_config.include_dialogue_context,
max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
)
forget_config = ForgettingEngineConfig(
offset=memory_config.offset,
lambda_time=memory_config.lambda_time,
lambda_mem=memory_config.lambda_mem,
)
return ExtractionPipelineConfig(
statement_extraction=stmt_config,
deduplication=dedup_config,
forgetting_engine=forget_config,
)
@staticmethod
def get_pruning_config(memory_config: MemoryConfig) -> dict:
"""Retrieve semantic pruning config from MemoryConfig.
Args:
memory_config: MemoryConfig object containing pruning settings.
Returns:
Dict suitable for PruningConfig.model_validate with keys:
- pruning_switch: bool
- pruning_scene: str
- pruning_threshold: float
"""
return {
"pruning_switch": memory_config.pruning_enabled,
"pruning_scene": memory_config.pruning_scene,
"pruning_threshold": memory_config.pruning_threshold,
}

View File

@@ -50,27 +50,6 @@ class MemoryStorageService:
def __init__(self): def __init__(self):
logger.info("MemoryStorageService initialized") logger.info("MemoryStorageService initialized")
def load_memory_config(self, config_id: int, db: Session) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
This method delegates to the centralized MemoryConfigService to avoid
code duplication with other services.
Args:
config_id: Configuration ID from database
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If validation fails
"""
return MemoryConfigService.load_memory_config(
config_id=config_id,
service_name="MemoryStorageService"
)
async def get_storage_info(self) -> dict: async def get_storage_info(self) -> dict:
""" """
Example wrapper method - retrieves storage information Example wrapper method - retrieves storage information
@@ -293,7 +272,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# Load configuration from database only using centralized manager # Load configuration from database only using centralized manager
try: try:
memory_config = MemoryConfigService.load_memory_config( config_service = MemoryConfigService(self.db)
memory_config = config_service.load_memory_config(
config_id=int(cid), config_id=int(cid),
service_name="MemoryStorageService.pilot_run_stream" service_name="MemoryStorageService.pilot_run_stream"
) )
@@ -320,13 +300,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
async def run_pipeline(): async def run_pipeline():
"""在后台执行管线并捕获异常""" """在后台执行管线并捕获异常"""
try: try:
from app.core.memory.main import main as pipeline_main from app.services.pilot_run_service import run_pilot_extraction
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True") logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
await pipeline_main( await run_pilot_extraction(
memory_config=memory_config,
dialogue_text=dialogue_text, dialogue_text=dialogue_text,
is_pilot_run=True, db=self.db,
progress_callback=progress_callback progress_callback=progress_callback,
) )
logger.info("[PILOT_RUN_STREAM] pipeline_main completed") logger.info("[PILOT_RUN_STREAM] pipeline_main completed")

View File

@@ -0,0 +1,219 @@
"""
Pilot Run Service - 试运行服务
用于执行记忆系统的试运行流程,不保存到 Neo4j。
"""
import os
import re
import time
from datetime import datetime
from typing import Awaitable, Callable, Optional
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.utils.config.config_utils import (
get_pipeline_config,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from sqlalchemy.orm import Session
logger = get_memory_logger(__name__)
async def run_pilot_extraction(
memory_config: MemoryConfig,
dialogue_text: str,
db: Session,
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
) -> None:
"""
执行试运行模式的知识提取流水线。
Args:
memory_config: 从数据库加载的内存配置对象
dialogue_text: 输入的对话文本
progress_callback: 可选的进度回调函数
- 参数1 (stage): 当前处理阶段标识符
- 参数2 (message): 人类可读的进度消息
- 参数3 (data): 可选的附加数据字典
"""
log_file = "logs/time.log"
os.makedirs(os.path.dirname(log_file), exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
pipeline_start = time.time()
neo4j_connector = None
try:
# 步骤 1: 初始化客户端
logger.info("Initializing clients...")
step_start = time.time()
client_factory = MemoryClientFactory(db)
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
neo4j_connector = Neo4jConnector()
log_time("Client Initialization", time.time() - step_start, log_file)
# 步骤 2: 解析对话文本
logger.info("Parsing dialogue text...")
step_start = time.time()
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
messages = [
ConversationMessage(role=r, msg=c.strip())
for r, c in matches
if c.strip()
]
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
if not messages:
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
context = ConversationContext(msgs=messages)
dialog = DialogData(
context=context,
ref_id="pilot_dialog_1",
group_id=str(memory_config.workspace_id),
user_id=str(memory_config.tenant_id),
apply_id=str(memory_config.config_id),
metadata={"source": "pilot_run", "input_type": "frontend_text"},
)
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
chunker_strategy=memory_config.chunker_strategy,
llm_client=llm_client,
)
logger.info(f"Processed dialogue text: {len(messages)} messages")
# 进度回调:输出每个分块的结果
if progress_callback:
for dlg in chunked_dialogs:
for i, chunk in enumerate(dlg.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dlg.id,
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
preprocessing_summary = {
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# 步骤 3: 初始化流水线编排器
logger.info("Initializing extraction orchestrator...")
step_start = time.time()
config = get_pipeline_config(memory_config)
logger.info(
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
# 步骤 4: 执行知识提取流水线
logger.info("Running extraction pipeline...")
step_start = time.time()
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=True,
)
# 解包 extraction_result tuple (与 main.py 保持一致)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file)
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# 步骤 5: 生成记忆摘要(与 main.py 保持一致)
try:
logger.info("Generating memory summaries...")
step_start = time.time()
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
summaries = await memory_summary_generation(
chunked_dialogs,
llm_client=llm_client,
embedder_client=embedder_client,
)
log_time("Memory Summary Generation", time.time() - step_start, log_file)
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
logger.info("Pilot run completed: Skipping Neo4j save")
except Exception as e:
logger.error(f"Pilot run failed: {e}", exc_info=True)
raise
finally:
if neo4j_connector:
try:
await neo4j_connector.close()
except Exception:
pass
total_time = time.time() - pipeline_start
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")
logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")

View File

@@ -176,7 +176,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
"""Celery task to process a read message via MemoryAgentService. """Celery task to process a read message via MemoryAgentService.
Args: Args:
group_id: Group ID for the memory agent group_id: Group ID for the memory agent (also used as end_user_id)
message: User message to process message: User message to process
history: Conversation history history: Conversation history
search_switch: Search switch parameter search_switch: Search switch parameter
@@ -190,9 +190,28 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
""" """
start_time = time.time() start_time = time.time()
# Resolve config_id if None
actual_config_id = config_id
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
connected_config = get_end_user_connected_config(group_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception as e:
# Log but continue - will fail later with proper error
pass
async def _run() -> str: async def _run() -> str:
service = MemoryAgentService() db = next(get_db())
return await service.read_memory(group_id, message, history, search_switch, config_id,storage_type,user_rag_memory_id) try:
service = MemoryAgentService()
return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
finally:
db.close()
try: try:
# 使用 nest_asyncio 来避免事件循环冲突 # 使用 nest_asyncio 来避免事件循环冲突
@@ -246,7 +265,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
"""Celery task to process a write message via MemoryAgentService. """Celery task to process a write message via MemoryAgentService.
Args: Args:
group_id: Group ID for the memory agent group_id: Group ID for the memory agent (also used as end_user_id)
message: Message to write message: Message to write
config_id: Optional configuration ID config_id: Optional configuration ID
@@ -258,9 +277,28 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
""" """
start_time = time.time() start_time = time.time()
# Resolve config_id if None
actual_config_id = config_id
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
connected_config = get_end_user_connected_config(group_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception as e:
# Log but continue - will fail later with proper error
pass
async def _run() -> str: async def _run() -> str:
service = MemoryAgentService() db = next(get_db())
return await service.write_memory(group_id, message, config_id,storage_type,user_rag_memory_id) try:
service = MemoryAgentService()
return await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
finally:
db.close()
try: try:
# 使用 nest_asyncio 来避免事件循环冲突 # 使用 nest_asyncio 来避免事件循环冲突