diff --git a/api/app/celery_app.py b/api/app/celery_app.py index ce7e9300..e431b210 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -1,9 +1,9 @@ import os from datetime import timedelta from urllib.parse import quote -from celery import Celery + from app.core.config import settings -from app.core.memory.utils.config.definitions import reload_configuration_from_database +from celery import Celery # 创建 Celery 应用实例 # broker: 任务队列(使用 Redis DB 0) @@ -13,7 +13,6 @@ celery_app = Celery( broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}", backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", ) -reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True) # 配置使用本地队列,避免与远程 worker 冲突 celery_app.conf.task_default_queue = 'localhost_test_wyl' @@ -22,6 +21,7 @@ celery_app.conf.task_default_routing_key = 'localhost_test_wyl' # macOS 兼容性配置 import platform + if platform.system() == 'Darwin': # macOS # 设置环境变量解决 fork 问题 os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') diff --git a/api/app/controllers/emotion_controller.py b/api/app/controllers/emotion_controller.py index 2ed00c43..144aa281 100644 --- a/api/app/controllers/emotion_controller.py +++ b/api/app/controllers/emotion_controller.py @@ -10,22 +10,21 @@ Routes: POST /emotion/suggestions - 获取个性化情绪建议 """ -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.orm import Session - -from app.core.response_utils import success, fail from app.core.error_codes import BizCode +from app.core.logging_config import get_api_logger +from app.core.response_utils import fail, success from app.dependencies import get_current_user, get_db from app.models.user_model import User -from app.schemas.response_schema import ApiResponse from app.schemas.emotion_schema import ( + EmotionHealthRequest, + 
EmotionSuggestionsRequest, EmotionTagsRequest, EmotionWordcloudRequest, - EmotionHealthRequest, - EmotionSuggestionsRequest ) +from app.schemas.response_schema import ApiResponse from app.services.emotion_analytics_service import EmotionAnalyticsService -from app.core.logging_config import get_api_logger +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session # 获取API专用日志器 api_logger = get_api_logger() @@ -230,7 +229,7 @@ async def get_emotion_suggestions( # 调用服务层 data = await emotion_service.generate_emotion_suggestions( end_user_id=request.group_id, - config_id=config_id + db=db ) api_logger.info( diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 884ee889..b7da943c 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -163,7 +163,8 @@ async def write_server( result = await memory_agent_service.write_memory( user_input.group_id, user_input.message, - config_id, + config_id, + db, storage_type, user_rag_memory_id ) @@ -280,6 +281,7 @@ async def read_server( user_input.history, user_input.search_switch, config_id, + db, storage_type, user_rag_memory_id ) @@ -548,6 +550,7 @@ async def get_write_task_result( @router.post("/status_type", response_model=ApiResponse) async def status_type( user_input: Write_UserInput, + db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """ @@ -561,7 +564,11 @@ async def status_type( """ api_logger.info(f"Status type check requested for group {user_input.group_id}") try: - result = await memory_agent_service.classify_message_type(user_input.message) + result = await memory_agent_service.classify_message_type( + user_input.message, + user_input.config_id, + db + ) return success(data=result) except Exception as e: api_logger.error(f"Message type classification failed: {str(e)}") @@ -636,6 +643,7 @@ async def get_hot_memory_tags_by_user_api( 
@router.get("/analytics/user_profile", response_model=ApiResponse) async def get_user_profile_api( end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), + db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """ @@ -659,7 +667,8 @@ async def get_user_profile_api( try: result = await memory_agent_service.get_user_profile( end_user_id=end_user_id, - current_user_id=str(current_user.id) + current_user_id=str(current_user.id), + db=db ) return success(data=result, msg="获取用户详情成功") except Exception as e: @@ -694,4 +703,41 @@ async def get_user_profile_api( # ) # except Exception as e: # api_logger.error(f"API docs retrieval failed: {str(e)}") -# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e)) \ No newline at end of file +# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e)) + + +@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse) +async def get_end_user_connected_config( + end_user_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + 获取终端用户关联的记忆配置 + + 通过以下流程获取配置: + 1. 根据 end_user_id 获取用户的 app_id + 2. 获取该应用的最新发布版本 + 3. 
从发布版本的 config 字段中提取 memory_config_id + + Args: + end_user_id: 终端用户ID + + Returns: + 包含 memory_config_id 和相关信息的响应 + """ + from app.services.memory_agent_service import ( + get_end_user_connected_config as get_config, + ) + + api_logger.info(f"Getting connected config for end_user: {end_user_id}") + + try: + result = get_config(end_user_id, db) + return success(data=result, msg="获取终端用户关联配置成功") + except ValueError as e: + api_logger.warning(f"End user config not found: {str(e)}") + return fail(BizCode.NOT_FOUND, str(e)) + except Exception as e: + api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e)) \ No newline at end of file diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index 8dfa6c50..71b5f6d0 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -1,22 +1,27 @@ import asyncio import time -from dotenv import load_dotenv -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.orm import Session -from sqlalchemy import text - from app.core.logging_config import get_api_logger +from app.core.memory.storage_services.reflection_engine.self_reflexion import ( + ReflectionConfig, + ReflectionEngine, +) from app.core.response_utils import success -from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine -from app.dependencies import get_current_user from app.db import get_db +from app.dependencies import get_current_user from app.models.user_model import User from app.repositories.data_config_repository import DataConfigRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService from app.schemas.memory_reflection_schemas import Memory_Reflection 
+from app.services.memory_reflection_service import ( + MemoryReflectionService, + WorkspaceAppService, +) from app.services.model_service import ModelConfigService +from dotenv import load_dotenv +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import text +from sqlalchemy.orm import Session load_dotenv() api_logger = get_api_logger() diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 3c33ad6e..380b660c 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -9,18 +9,19 @@ LangChain Agent 封装 """ import os import time -from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence -from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage -from langchain_core.tools import BaseTool -from langchain.agents import create_agent +from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence +from app.core.logging_config import get_business_logger from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType -from app.core.logging_config import get_business_logger from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task +from langchain.agents import create_agent +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.tools import BaseTool + logger = get_business_logger() @@ -198,10 +199,24 @@ class LangChainAgent: """ message_chat= message start_time = time.time() - if config_id == None: - actual_config_id = os.getenv("config_id") - else: - actual_config_id = config_id + actual_config_id = config_id + # If config_id is None, try to get from end_user's connected config + if actual_config_id is None and end_user_id: + try: + from app.db import get_db 
+ from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + db = next(get_db()) + try: + connected_config = get_end_user_connected_config(end_user_id, db) + actual_config_id = connected_config.get("memory_config_id") + except Exception as e: + logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}") + finally: + db.close() + except Exception as e: + logger.warning(f"Failed to get db session: {e}") actual_end_user_id = end_user_id if end_user_id is not None else "unknown" logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') @@ -295,10 +310,24 @@ class LangChainAgent: logger.info(f" Tool count: {len(self.tools) if self.tools else 0}") logger.info("=" * 80) message_chat = message - if config_id == None: - actual_config_id = os.getenv("config_id") - else: - actual_config_id = config_id + actual_config_id = config_id + # If config_id is None, try to get from end_user's connected config + if actual_config_id is None and end_user_id: + try: + from app.db import get_db + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + db = next(get_db()) + try: + connected_config = get_end_user_connected_config(end_user_id, db) + actual_config_id = connected_config.get("memory_config_id") + except Exception as e: + logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}") + finally: + db.close() + except Exception as e: + logger.warning(f"Failed to get db session: {e}") history_term_memory = await self.term_memory_redis_read(end_user_id) if memory_flag: diff --git a/api/app/core/config.py b/api/app/core/config.py index bf5ff45a..da558ac9 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -1,7 +1,8 @@ -import os import json +import os from pathlib import Path -from typing import Dict, Any, Optional +from typing import Any, Dict, Optional + from 
dotenv import load_dotenv load_dotenv() @@ -81,6 +82,7 @@ class Settings: VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query") # Langfuse configuration + LANGFUSE_ENABLED: bool = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true" LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "") LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "") LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "") @@ -153,9 +155,6 @@ class Settings: # Memory Module Configuration (internal) MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory") - MEMORY_CONFIG_FILE: str = os.getenv("MEMORY_CONFIG_FILE", "config.json") - MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json") - MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json") # Tool Management Configuration TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools") @@ -178,65 +177,6 @@ class Settings: return str(base_path / filename) return str(base_path) - def get_memory_config_path(self, config_file: str = "") -> str: - """ - Get the full path for memory module configuration files. - - Args: - config_file: Optional config filename (defaults to MEMORY_CONFIG_FILE) - - Returns: - Full path to the config file - """ - if not config_file: - config_file = self.MEMORY_CONFIG_FILE - return str(Path(self.MEMORY_CONFIG_DIR) / config_file) - - def load_memory_config(self) -> Dict[str, Any]: - """ - Load memory module configuration from config.json. - - Returns: - Dictionary containing memory configuration - """ - config_path = self.get_memory_config_path(self.MEMORY_CONFIG_FILE) - try: - with open(config_path, "r", encoding="utf-8") as f: - return json.load(f) - except (FileNotFoundError, json.JSONDecodeError) as e: - print(f"Warning: Memory config file not found or malformed at {config_path}. 
Error: {e}") - return {} - - def load_memory_runtime_config(self) -> Dict[str, Any]: - """ - Load memory module runtime configuration from runtime.json. - - Returns: - Dictionary containing runtime configuration - """ - runtime_path = self.get_memory_config_path(self.MEMORY_RUNTIME_FILE) - try: - with open(runtime_path, "r", encoding="utf-8") as f: - return json.load(f) - except (FileNotFoundError, json.JSONDecodeError) as e: - print(f"Warning: Memory runtime config not found or malformed at {runtime_path}. Error: {e}") - return {"selections": {}} - - def load_memory_dbrun_config(self) -> Dict[str, Any]: - """ - Load memory module database run configuration from dbrun.json. - - Returns: - Dictionary containing dbrun configuration - """ - dbrun_path = self.get_memory_config_path(self.MEMORY_DBRUN_FILE) - try: - with open(dbrun_path, "r", encoding="utf-8") as f: - return json.load(f) - except (FileNotFoundError, json.JSONDecodeError) as e: - print(f"Warning: Memory dbrun config not found or malformed at {dbrun_path}. Error: {e}") - return {"selections": {}} - def ensure_memory_output_dir(self) -> None: """ Ensure the memory output directory exists. 
diff --git a/api/app/core/memory/agent/mcp_server/services/search_service.py b/api/app/core/memory/agent/mcp_server/services/search_service.py index b0a007cd..be96bb64 100644 --- a/api/app/core/memory/agent/mcp_server/services/search_service.py +++ b/api/app/core/memory/agent/mcp_server/services/search_service.py @@ -141,7 +141,7 @@ class SearchService: cleaned_query = self.clean_query(question) try: - # Execute search using embedding_model_id from memory_config + # Execute search using memory_config answer = await run_hybrid_search( query_text=cleaned_query, search_type=search_type, @@ -149,7 +149,7 @@ class SearchService: limit=limit, include=include, output_path=output_path, - embedding_id=str(config.embedding_model_id), + memory_config=config, rerank_alpha=rerank_alpha, ) diff --git a/api/app/core/memory/agent/mcp_server/tools/data_tools.py b/api/app/core/memory/agent/mcp_server/tools/data_tools.py index 22dadd7f..631f7fd7 100644 --- a/api/app/core/memory/agent/mcp_server/tools/data_tools.py +++ b/api/app/core/memory/agent/mcp_server/tools/data_tools.py @@ -13,7 +13,8 @@ from app.core.memory.agent.mcp_server.models.retrieval_models import ( ) from app.core.memory.agent.mcp_server.server import get_context_resource from app.core.memory.agent.utils.write_tools import write -from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context from app.schemas.memory_config_schema import MemoryConfig from mcp.server.fastmcp import Context @@ -41,8 +42,10 @@ async def Data_type_differentiation( # Extract services from context template_service = get_context_resource(ctx, 'template_service') - # Get LLM client from memory_config - llm_client = get_llm_client_from_config(memory_config) + # Get LLM client from memory_config using factory pattern + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = 
factory.get_llm_client_from_config(memory_config) # Render template try: diff --git a/api/app/core/memory/agent/mcp_server/tools/problem_tools.py b/api/app/core/memory/agent/mcp_server/tools/problem_tools.py index 892fdbdd..49812e38 100644 --- a/api/app/core/memory/agent/mcp_server/tools/problem_tools.py +++ b/api/app/core/memory/agent/mcp_server/tools/problem_tools.py @@ -16,7 +16,8 @@ from app.core.memory.agent.mcp_server.models.problem_models import ( ) from app.core.memory.agent.mcp_server.server import get_context_resource from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal -from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context from app.schemas.memory_config_schema import MemoryConfig from mcp.server.fastmcp import Context @@ -56,7 +57,9 @@ async def Split_The_Problem( session_service = get_context_resource(ctx, "session_service") # Get LLM client from memory_config - llm_client = get_llm_client_from_config(memory_config) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client_from_config(memory_config) # Extract user ID from session user_id = session_service.resolve_user_id(sessionid) @@ -190,7 +193,9 @@ async def Problem_Extension( session_service = get_context_resource(ctx, "session_service") # Get LLM client from memory_config - llm_client = get_llm_client_from_config(memory_config) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client_from_config(memory_config) # Resolve session ID from usermessages from app.core.memory.agent.utils.messages_tool import Resolve_username diff --git a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py b/api/app/core/memory/agent/mcp_server/tools/summary_tools.py index 8b6b7ae4..6d5012f1 100644 --- a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py +++ 
b/api/app/core/memory/agent/mcp_server/tools/summary_tools.py @@ -21,8 +21,9 @@ from app.core.memory.agent.utils.messages_tool import ( Resolve_username, Summary_messages_deal, ) -from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.rag.nlp.search import knowledge_retrieval +from app.db import get_db_context from app.schemas.memory_config_schema import MemoryConfig from dotenv import load_dotenv from mcp.server.fastmcp import Context @@ -66,7 +67,9 @@ async def Summary( session_service = get_context_resource(ctx, "session_service") # Get LLM client from memory_config - llm_client = get_llm_client_from_config(memory_config) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client_from_config(memory_config) # Resolve session ID sessionid = Resolve_username(usermessages) @@ -210,7 +213,9 @@ async def Retrieve_Summary( session_service = get_context_resource(ctx, "session_service") # Get LLM client from memory_config - llm_client = get_llm_client_from_config(memory_config) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client_from_config(memory_config) # Resolve session ID sessionid = Resolve_username(usermessages) @@ -425,7 +430,9 @@ async def Input_Summary( search_service = get_context_resource(ctx, "search_service") # Get LLM client from memory_config - llm_client = get_llm_client_from_config(memory_config) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client_from_config(memory_config) # Resolve session ID sessionid = Resolve_username(usermessages) or "" diff --git a/api/app/core/memory/agent/utils/type_classifier.py b/api/app/core/memory/agent/utils/type_classifier.py index d1b75d43..3e5358bd 100644 --- a/api/app/core/memory/agent/utils/type_classifier.py +++ b/api/app/core/memory/agent/utils/type_classifier.py @@ 
-1,15 +1,14 @@ """ Type classification utility for distinguishing read/write operations. """ -from jinja2 import Template -from pydantic import BaseModel - +from app.core.config import settings from app.core.logging_config import get_agent_logger, log_prompt_rendering from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ from app.core.memory.agent.utils.messages_tool import read_template_file -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.config import settings - +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context +from jinja2 import Template +from pydantic import BaseModel logger = get_agent_logger(__name__) @@ -44,7 +43,9 @@ async def status_typle(messages: str, llm_model_id: str) -> dict: "message": f"Prompt rendering failed: {str(e)}" } - llm_client = get_llm_client(llm_model_id) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(llm_model_id) try: structured = await llm_client.response_structured( diff --git a/api/app/core/memory/agent/utils/verify_tool.py b/api/app/core/memory/agent/utils/verify_tool.py index fe721770..3a74ee25 100644 --- a/api/app/core/memory/agent/utils/verify_tool.py +++ b/api/app/core/memory/agent/utils/verify_tool.py @@ -1,18 +1,19 @@ -from typing import TypedDict, Annotated, List, Any -from langchain_core.messages import AnyMessage -from langgraph.constants import START, END -from langgraph.graph import StateGraph, add_messages import asyncio import json -from dotenv import load_dotenv, find_dotenv import os -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from langchain_core.messages import HumanMessage -from jinja2 import Environment, FileSystemLoader -from app.core.memory.agent.utils.messages_tool import _to_openai_messages -from app.core.memory.utils.llm.llm_utils import get_llm_client +from typing import Annotated, Any, List, TypedDict + # Removed global variable 
imports - use dependency injection instead from app.core.logging_config import get_agent_logger +from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ +from app.core.memory.agent.utils.messages_tool import _to_openai_messages +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context +from dotenv import find_dotenv, load_dotenv +from jinja2 import Environment, FileSystemLoader +from langchain_core.messages import AnyMessage, HumanMessage +from langgraph.constants import END, START +from langgraph.graph import StateGraph, add_messages load_dotenv(find_dotenv()) @@ -53,7 +54,9 @@ class VerifyTool: async def model_1(self, state: State) -> State: if not self.llm_model_id: raise ValueError("llm_model_id is required but not provided") - llm_client = get_llm_client(self.llm_model_id) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(self.llm_model_id) response_content = await llm_client.chat( messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])] ) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 60259873..f09b35e8 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -13,13 +13,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator ExtractionOrchestrator, ) from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( - Memory_summary_generation, + memory_summary_generation, ) -from app.core.memory.utils.embedder.embedder_utils import ( - get_embedder_client_from_config, -) -from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time +from app.db import get_db_context from 
app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j @@ -67,9 +65,11 @@ async def write( logger.info(f"Chunker strategy: {chunker_strategy}") logger.info(f"Group ID: {group_id}") - # Construct clients from memory_config - llm_client = get_llm_client_from_config(memory_config) - embedder_client = get_embedder_client_from_config(memory_config) + # Construct clients from memory_config using factory pattern with db session + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client_from_config(memory_config) + embedder_client = factory.get_embedder_client_from_config(memory_config) logger.info("LLM and embedding clients constructed") # Initialize timing log @@ -100,7 +100,7 @@ async def write( # Step 2: Initialize and run ExtractionOrchestrator step_start = time.time() from app.core.memory.utils.config.config_utils import get_pipeline_config - pipeline_config = get_pipeline_config() + pipeline_config = get_pipeline_config(memory_config) orchestrator = ExtractionOrchestrator( llm_client=llm_client, @@ -155,8 +155,8 @@ async def write( # Step 4: Generate Memory summaries and save to Neo4j step_start = time.time() try: - summaries = await Memory_summary_generation( - chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id + summaries = await memory_summary_generation( + chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client ) try: diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index cfcff994..2aa286ba 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -35,7 +35,9 @@ except NameError: import json from app.core.config import settings -from app.core.memory.utils.llm.llm_utils import 
get_llm_client +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context +from app.services.memory_config_service import MemoryConfigService #TODO: Fix this # Default values (previously from definitions.py) @@ -47,11 +49,37 @@ class FilteredTags(BaseModel): """用于接收LLM筛选后的核心标签列表的模型。""" meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") -async def filter_tags_with_llm(tags: List[str], llm_client) -> List[str]: +async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: """ 使用LLM筛选标签列表,仅保留具有代表性的核心名词。 """ try: + # Get config_id using get_end_user_connected_config + with get_db_context() as db: + try: + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + connected_config = get_end_user_connected_config(group_id, db) + config_id = connected_config.get("memory_config_id") + + if config_id: + # Use the config_id to get the proper LLM client + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config(config_id) + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(memory_config.llm_model_id) + else: + # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config + # Fallback to default LLM if no config found + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(DEFAULT_LLM_ID) + except Exception as e: + print(f"Failed to get user connected config, using default LLM: {e}") + # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config + # Fallback to default LLM + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(DEFAULT_LLM_ID) # 3. 构建Prompt tag_list_str = ", ".join(tags) @@ -156,8 +184,7 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u raw_tag_names = [tag for tag, freq in raw_tags_with_freq] # 2. 
初始化LLM客户端并使用LLM筛选出有意义的标签 - llm_client = get_llm_client(DEFAULT_LLM_ID) - meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, llm_client) + meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) # 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序) final_tags = [] diff --git a/api/app/core/memory/analytics/memory_insight.py b/api/app/core/memory/analytics/memory_insight.py index 35ed466f..06791702 100644 --- a/api/app/core/memory/analytics/memory_insight.py +++ b/api/app/core/memory/analytics/memory_insight.py @@ -18,8 +18,10 @@ if src_path not in sys.path: sys.path.insert(0, src_path) from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.services.memory_config_service import MemoryConfigService from pydantic import BaseModel, Field #TODO: Fix this @@ -59,7 +61,33 @@ class MemoryInsight: def __init__(self, user_id: str): self.user_id = user_id self.neo4j_connector = Neo4jConnector() - self.llm_client = get_llm_client(DEFAULT_LLM_ID) + + # Get config_id using get_end_user_connected_config + with get_db_context() as db: + try: + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + connected_config = get_end_user_connected_config(user_id, db) + config_id = connected_config.get("memory_config_id") + + if config_id: + # Use the config_id to get the proper LLM client + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config(config_id) + factory = MemoryClientFactory(db) + self.llm_client = factory.get_llm_client(memory_config.llm_model_id) + else: + # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config + # Fallback to default LLM if no config found + factory = MemoryClientFactory(db) + self.llm_client = 
factory.get_llm_client(DEFAULT_LLM_ID) + except Exception as e: + print(f"Failed to get user connected config, using default LLM: {e}") + # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config + # Fallback to default LLM + factory = MemoryClientFactory(db) + self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID) async def close(self): """关闭数据库连接。""" diff --git a/api/app/core/memory/analytics/user_summary.py b/api/app/core/memory/analytics/user_summary.py index eb6bc83a..3f4f4a2d 100644 --- a/api/app/core/memory/analytics/user_summary.py +++ b/api/app/core/memory/analytics/user_summary.py @@ -25,8 +25,10 @@ except Exception: pass from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.services.memory_config_service import MemoryConfigService #TODO: Fix this @@ -47,7 +49,33 @@ class UserSummary: def __init__(self, user_id: str): self.user_id = user_id self.connector = Neo4jConnector() - self.llm = get_llm_client(DEFAULT_LLM_ID) + + # Get config_id using get_end_user_connected_config + with get_db_context() as db: + try: + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + connected_config = get_end_user_connected_config(user_id, db) + config_id = connected_config.get("memory_config_id") + + if config_id: + # Use the config_id to get the proper LLM client + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config(config_id) + factory = MemoryClientFactory(db) + self.llm = factory.get_llm_client(memory_config.llm_model_id) + else: + # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config + # Fallback to default LLM if no config found + factory = MemoryClientFactory(db) + self.llm = 
factory.get_llm_client(DEFAULT_LLM_ID) + except Exception as e: + print(f"Failed to get user connected config, using default LLM: {e}") + # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config + # Fallback to default LLM + factory = MemoryClientFactory(db) + self.llm = factory.get_llm_client(DEFAULT_LLM_ID) async def close(self): await self.connector.close() diff --git a/api/app/core/memory/evaluation/extraction_utils.py b/api/app/core/memory/evaluation/extraction_utils.py index b45ea7e4..9afa228c 100644 --- a/api/app/core/memory/evaluation/extraction_utils.py +++ b/api/app/core/memory/evaluation/extraction_utils.py @@ -1,22 +1,34 @@ -import os import asyncio import json -from typing import List, Dict, Any, Optional -from datetime import datetime +import os import re +from datetime import datetime +from typing import Any, Dict, List, Optional from app.core.memory.llm_tools.openai_client import LLMClient -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker -from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.utils.config.definitions import SELECTED_CHUNKER_STRATEGY, SELECTED_EMBEDDING_ID +from app.core.memory.models.message_models import ( + ConversationContext, + ConversationMessage, + DialogData, +) # 使用新的模块化架构 -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator +from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( + ExtractionOrchestrator, +) +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( + DialogueChunker, +) +from app.core.memory.utils.config.definitions import ( + SELECTED_CHUNKER_STRATEGY, + SELECTED_EMBEDDING_ID, +) +from 
app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context # Import from database module from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j +from app.repositories.neo4j.neo4j_connector import Neo4jConnector # Cypher queries for evaluation # Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py @@ -52,7 +64,9 @@ async def ingest_contexts_via_full_pipeline( llm_available = True try: from app.core.memory.utils.config import definitions as config_defs - llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) except Exception as e: print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}") llm_available = False @@ -133,12 +147,13 @@ async def ingest_contexts_via_full_pipeline( return False # 初始化 embedder 客户端 - from app.core.models.base import RedBearModelConfig - from app.core.memory.utils.config.config_utils import get_embedder_config from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient + from app.core.models.base import RedBearModelConfig + from app.services.memory_config_service import MemoryConfigService try: - embedder_config_dict = get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID) + with get_db_context() as db: + embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID) embedder_config = RedBearModelConfig(**embedder_config_dict) embedder_client = OpenAIEmbedderClient(embedder_config) except Exception as e: @@ -236,15 +251,15 @@ async def ingest_contexts_via_full_pipeline( print("[Ingestion] Generating memory summaries...") try: from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( - Memory_summary_generation, + memory_summary_generation, ) - from 
app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges + from app.repositories.neo4j.add_nodes import add_memory_summary_nodes - summaries = await Memory_summary_generation( + summaries = await memory_summary_generation( chunked_dialogs=dialog_data_list, llm_client=llm_client, - embedding_id=embedding_name or SELECTED_EMBEDDING_ID + embedder_client=embedder_client ) print(f"[Ingestion] Generated {len(summaries)} memory summaries") except Exception as e: diff --git a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py b/api/app/core/memory/evaluation/locomo/locomo_benchmark.py index 67f41771..4992aa29 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py +++ b/api/app/core/memory/evaluation/locomo/locomo_benchmark.py @@ -15,7 +15,7 @@ import json import os import time from datetime import datetime -from typing import List, Dict, Any, Optional +from typing import Any, Dict, List, Optional try: from dotenv import load_dotenv @@ -23,37 +23,38 @@ except ImportError: def load_dotenv(): pass -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.config_utils import get_embedder_config -from app.core.memory.utils.definitions import ( - PROJECT_ROOT, - SELECTED_GROUP_ID, - SELECTED_LLM_ID, - SELECTED_EMBEDDING_ID -) -from app.core.memory.utils.llm_utils import get_llm_client +from app.core.memory.client_factory import MemoryClientFactory from app.core.memory.evaluation.common.metrics import ( - f1_score, + avg_context_tokens, bleu1, + f1_score, jaccard, latency_stats, - avg_context_tokens ) from app.core.memory.evaluation.locomo.locomo_metrics import ( + get_category_name, locomo_f1_score, locomo_multi_f1, - get_category_name ) from app.core.memory.evaluation.locomo.locomo_utils import ( - 
load_locomo_data, extract_conversations, + ingest_conversations_if_needed, + load_locomo_data, resolve_temporal_references, - select_and_format_information, retrieve_relevant_information, - ingest_conversations_if_needed + select_and_format_information, ) +from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +from app.core.memory.utils.definitions import ( + PROJECT_ROOT, + SELECTED_EMBEDDING_ID, + SELECTED_GROUP_ID, + SELECTED_LLM_ID, +) +from app.core.models.base import RedBearModelConfig +from app.db import get_db_context +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.services.memory_config_service import MemoryConfigService async def run_locomo_benchmark( @@ -160,10 +161,16 @@ async def run_locomo_benchmark( # Step 3: Initialize clients print("🔧 Initializing clients...") connector = Neo4jConnector() - llm_client = get_llm_client(SELECTED_LLM_ID) + + # Initialize LLM client with database context + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(SELECTED_LLM_ID) # Initialize embedder - cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) + with get_db_context() as db: + config_service = MemoryConfigService(db) + cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) embedder = OpenAIEmbedderClient( model_config=RedBearModelConfig.model_validate(cfg_dict) ) diff --git a/api/app/core/memory/evaluation/locomo/locomo_test.py b/api/app/core/memory/evaluation/locomo/locomo_test.py index ad51931a..a8fa1820 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_test.py +++ b/api/app/core/memory/evaluation/locomo/locomo_test.py @@ -1,14 +1,16 @@ # file name: check_neo4j_connection_fixed.py import asyncio -import os -import sys import json -import time import math +import os import re +import sys +import time from datetime import datetime, timedelta -from typing import List, Dict, Any +from typing import Any, Dict, List + from dotenv import 
load_dotenv + # 1 # 添加项目根目录到路径 current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -34,7 +36,7 @@ def _loc_normalize(text: str) -> str: # 尝试从 metrics.py 导入基础指标 try: - from common.metrics import f1_score, bleu1, jaccard + from common.metrics import bleu1, f1_score, jaccard print("✅ 从 metrics.py 导入基础指标成功") except ImportError as e: print(f"❌ 从 metrics.py 导入失败: {e}") @@ -111,10 +113,14 @@ try: # 尝试从不同位置导入 try: - from locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times + from locomo.qwen_search_eval import ( + _resolve_relative_times, + loc_f1_score, + loc_multi_f1, + ) print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功") except ImportError: - from qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times + from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1 print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功") except ImportError as e: @@ -429,13 +435,17 @@ async def run_enhanced_evaluation(): return None # 修正导入路径:使用 app.core.memory.src 前缀 - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - from app.repositories.neo4j.graph_search import search_graph_by_embedding + from app.core.memory.client_factory import MemoryClientFactory from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient + from app.core.memory.utils.config.definitions import ( + SELECTED_EMBEDDING_ID, + SELECTED_LLM_ID, + ) from app.core.models.base import RedBearModelConfig - from app.core.memory.utils.llm.llm_utils import get_llm_client - from app.core.memory.utils.config.config_utils import get_embedder_config - from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_EMBEDDING_ID + from app.db import get_db_context + from app.repositories.neo4j.graph_search import search_graph_by_embedding + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + from app.services.memory_config_service import MemoryConfigService # 加载数据 # 获取项目根目录 @@ -458,10 +468,14 @@ async 
def run_enhanced_evaluation(): # 初始化增强监控器 monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6) - llm = get_llm_client(SELECTED_LLM_ID) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm = factory.get_llm_client(SELECTED_LLM_ID) # 初始化embedder - cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) + with get_db_context() as db: + config_service = MemoryConfigService(db) + cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) embedder = OpenAIEmbedderClient( model_config=RedBearModelConfig.model_validate(cfg_dict) ) diff --git a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py b/api/app/core/memory/evaluation/locomo/qwen_search_eval.py index cbc0bfbd..e7a95e71 100644 --- a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py +++ b/api/app/core/memory/evaluation/locomo/qwen_search_eval.py @@ -2,10 +2,11 @@ import argparse import asyncio import json import os +import statistics import time from datetime import datetime, timedelta -from typing import List, Dict, Any -import statistics +from typing import Any, Dict, List + try: from dotenv import load_dotenv except Exception: @@ -13,16 +14,31 @@ except Exception: return None import re -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding + +from app.core.memory.client_factory import MemoryClientFactory +from app.core.memory.evaluation.common.metrics import ( + avg_context_tokens, + bleu1, + jaccard, + latency_stats, +) +from app.core.memory.evaluation.common.metrics import f1_score as common_f1 +from app.core.memory.evaluation.extraction_utils import ( + ingest_contexts_via_full_pipeline, +) from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.config.config_utils import get_embedder_config from app.core.memory.storage_services.search import 
run_hybrid_search -from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline -from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens +from app.core.memory.utils.config.definitions import ( + PROJECT_ROOT, + SELECTED_EMBEDDING_ID, + SELECTED_GROUP_ID, + SELECTED_LLM_ID, +) +from app.core.models.base import RedBearModelConfig +from app.db import get_db_context +from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.services.memory_config_service import MemoryConfigService # 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现) @@ -327,9 +343,13 @@ async def run_locomo_eval( await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True) # 使用异步LLM客户端 - llm_client = get_llm_client(SELECTED_LLM_ID) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(SELECTED_LLM_ID) # 初始化embedder用于直接调用 - cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) + with get_db_context() as db: + config_service = MemoryConfigService(db) + cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) embedder = OpenAIEmbedderClient( model_config=RedBearModelConfig.model_validate(cfg_dict) ) diff --git a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py b/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py index d7cd711b..58652033 100644 --- a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py +++ b/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py @@ -2,11 +2,11 @@ import argparse import asyncio import json import os -import time import re import statistics +import 
time from datetime import datetime, timedelta -from typing import List, Dict, Any +from typing import Any, Dict, List try: from dotenv import load_dotenv @@ -16,6 +16,7 @@ except Exception: # 确保可以找到 src 及项目根路径 import sys + _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR))) _SRC_DIR = os.path.join(_PROJECT_ROOT, "src") @@ -25,19 +26,33 @@ for _p in (_SRC_DIR, _PROJECT_ROOT): # 与现有评估脚本保持一致的导入方式 from app.repositories.neo4j.neo4j_connector import Neo4jConnector + try: # 优先从 extraction_utils1 导入 - from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline # type: ignore + from app.core.memory.evaluation.extraction_utils import ( + ingest_contexts_via_full_pipeline, # type: ignore + ) except Exception: ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查 -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.core.memory.client_factory import MemoryClientFactory +from app.core.memory.evaluation.common.metrics import ( + avg_context_tokens, + jaccard, + latency_stats, +) +from app.core.memory.evaluation.common.metrics import f1_score as common_f1 from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME -from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID -from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens +from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +from app.core.memory.utils.config.definitions import ( + PROJECT_ROOT, + SELECTED_EMBEDDING_ID, + SELECTED_LLM_ID, +) +from 
app.core.models.base import RedBearModelConfig +from app.db import get_db_context +from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding +from app.services.memory_config_service import MemoryConfigService + try: from app.core.memory.evaluation.common.metrics import exact_match except Exception: @@ -686,9 +701,13 @@ async def run_longmemeval_test( ) # 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端 - llm_client = get_llm_client(SELECTED_LLM_ID) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(SELECTED_LLM_ID) connector = Neo4jConnector() - cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) + with get_db_context() as db: + config_service = MemoryConfigService(db) + cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) embedder = OpenAIEmbedderClient( model_config=RedBearModelConfig.model_validate(cfg_dict) ) @@ -748,10 +767,10 @@ async def run_longmemeval_test( if stmt_text: contexts_all.append(stmt_text) - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) + # for sm in summaries: + # summary_text = str(sm.get("summary", "")).strip() + # if summary_text: + # contexts_all.append(summary_text) # 实体摘要(最多3个) scored = [e for e in entities if e.get("score") is not None] diff --git a/api/app/core/memory/evaluation/longmemeval/test_eval.py b/api/app/core/memory/evaluation/longmemeval/test_eval.py index 550de2d2..a0038260 100644 --- a/api/app/core/memory/evaluation/longmemeval/test_eval.py +++ b/api/app/core/memory/evaluation/longmemeval/test_eval.py @@ -2,11 +2,11 @@ import argparse import asyncio import json import os -import time import re import statistics +import time from datetime import datetime, timedelta -from typing import List, Dict, Any +from typing import Any, Dict, List try: from dotenv import load_dotenv @@ -15,15 +15,26 @@ except Exception: return None # 与现有评估脚本保持一致的导入方式 -from 
app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.config_utils import get_embedder_config -from app.core.memory.utils.llm_utils import get_llm_client +from app.core.memory.client_factory import MemoryClientFactory +from app.core.memory.evaluation.common.metrics import ( + avg_context_tokens, + jaccard, + latency_stats, +) +from app.core.memory.evaluation.common.metrics import f1_score as common_f1 from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME -from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID -from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens +from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +from app.core.memory.utils.config.definitions import ( + PROJECT_ROOT, + SELECTED_EMBEDDING_ID, + SELECTED_LLM_ID, +) +from app.core.models.base import RedBearModelConfig +from app.db import get_db_context +from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.services.memory_config_service import MemoryConfigService + try: from app.core.memory.evaluation.common.metrics import exact_match except Exception: @@ -647,9 +658,13 @@ async def run_longmemeval_test( items = qa_list[start_index:start_index + sample_size] # 初始化组件 - 使用异步LLM客户端 - llm_client = get_llm_client(SELECTED_LLM_ID) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(SELECTED_LLM_ID) connector = Neo4jConnector() - cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) + with get_db_context() as db: + config_service = 
MemoryConfigService(db) + cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) embedder = OpenAIEmbedderClient( model_config=RedBearModelConfig.model_validate(cfg_dict) ) diff --git a/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py b/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py index f41d8f10..3e6a1216 100644 --- a/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py +++ b/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py @@ -4,19 +4,35 @@ import json import os import time from datetime import datetime -from typing import List, Dict, Any +from typing import TYPE_CHECKING, Any, Dict, List + +if TYPE_CHECKING: + from app.schemas.memory_config_schema import MemoryConfig + try: from dotenv import load_dotenv except Exception: def load_dotenv(): return None -from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.core.memory.client_factory import MemoryClientFactory +from app.core.memory.evaluation.common.metrics import ( + avg_context_tokens, + exact_match, + latency_stats, +) +from app.core.memory.evaluation.extraction_utils import ( + ingest_contexts_via_full_pipeline, +) from app.core.memory.storage_services.search import run_hybrid_search -from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline -from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens +from app.core.memory.utils.config.definitions import ( + PROJECT_ROOT, + SELECTED_EMBEDDING_ID, + SELECTED_GROUP_ID, + SELECTED_LLM_ID, +) +from app.db import get_db_context +from app.repositories.neo4j.neo4j_connector import Neo4jConnector def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: @@ -119,7 +135,7 @@ def _combine_dialogues_for_hybrid(results: 
Dict[str, Any]) -> List[Dict[str, Any return merged -async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid") -> Dict[str, Any]: +async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: group_id = group_id or SELECTED_GROUP_ID # Load data data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") @@ -134,7 +150,9 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s await ingest_contexts_via_full_pipeline(contexts, group_id) # LLM client (使用异步调用) - llm_client = get_llm_client(SELECTED_LLM_ID) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(SELECTED_LLM_ID) # Evaluate each item connector = Neo4jConnector() @@ -159,6 +177,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s limit=search_limit, include=["dialogues", "statements", "entities"], output_path=None, + memory_config=memory_config, ) except Exception: results = None @@ -242,7 +261,11 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip()) # Metrics: F1, BLEU-1, Jaccard; keep exact match for reference correct_flags.append(exact_match(pred, reference)) - from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard + from app.core.memory.evaluation.common.metrics import ( + bleu1, + f1_score, + jaccard, + ) f1s.append(f1_score(str(pred), str(reference))) b1s.append(bleu1(str(pred), str(reference))) 
jss.append(jaccard(str(pred), str(reference))) diff --git a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py index c8d89a4d..ebbe6e7e 100644 --- a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py +++ b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py @@ -2,10 +2,10 @@ import argparse import asyncio import json import os +import re import time from datetime import datetime -from typing import List, Dict, Any -import re +from typing import Any, Dict, List try: from dotenv import load_dotenv @@ -15,6 +15,7 @@ except Exception: # 路径与模块导入保持与现有评估脚本一致 import sys + _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR)) _SRC_DIR = os.path.join(_PROJECT_ROOT, "src") @@ -23,17 +24,27 @@ for _p in (_SRC_DIR, _PROJECT_ROOT): sys.path.insert(0, _p) # 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding +from app.core.memory.client_factory import MemoryClientFactory +from app.core.memory.evaluation.common.metrics import ( + avg_context_tokens, + exact_match, + latency_stats, +) from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +from app.core.memory.utils.config.definitions import ( + PROJECT_ROOT, + SELECTED_EMBEDDING_ID, + SELECTED_GROUP_ID, + SELECTED_LLM_ID, +) from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.config_utils import get_embedder_config +from app.db import get_db_context +from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.services.memory_config_service import MemoryConfigService -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.utils.config.definitions 
import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID -from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens try: - from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard + from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard except Exception: # 兜底:简单实现(必要时) def f1_score(pred: str, ref: str) -> float: @@ -226,13 +237,17 @@ async def run_memsciqa_test( items = all_items[start_index:start_index + sample_size] # 初始化 LLM(纯测试:不进行摄入) - llm = get_llm_client(SELECTED_LLM_ID) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm = factory.get_llm_client(SELECTED_LLM_ID) # 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test) connector = Neo4jConnector() embedder = None if search_type in ("embedding", "hybrid"): - cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID) + with get_db_context() as db: + config_service = MemoryConfigService(db) + cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) embedder = OpenAIEmbedderClient( model_config=RedBearModelConfig.model_validate(cfg_dict) ) diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index bcaa52c2..dce7b495 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -5,18 +5,17 @@ OpenAI LLM 客户端实现 """ import asyncio -from typing import List, Dict, Any import json import logging +from typing import Any, Dict, List -from pydantic import BaseModel -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.output_parsers import PydanticOutputParser - +from app.core.config import settings +from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException from app.core.models.base import RedBearModelConfig from app.core.models.llm import RedBearLLM -from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException -from 
app.core.memory.utils.config.definitions import LANGFUSE_ENABLED +from langchain_core.output_parsers import PydanticOutputParser +from langchain_core.prompts import ChatPromptTemplate +from pydantic import BaseModel logger = logging.getLogger(__name__) @@ -43,7 +42,7 @@ class OpenAIClient(LLMClient): # 初始化 Langfuse 回调处理器(如果启用) self.langfuse_handler = None - if LANGFUSE_ENABLED: + if settings.LANGFUSE_ENABLED: try: from langfuse.langchain import CallbackHandler self.langfuse_handler = CallbackHandler() diff --git a/api/app/core/memory/main.py b/api/app/core/memory/main.py deleted file mode 100644 index 68bb1de9..00000000 --- a/api/app/core/memory/main.py +++ /dev/null @@ -1,430 +0,0 @@ -""" -MemSci 记忆系统主入口 - 重构版本 - -该模块是重构后的记忆系统主入口,使用新的模块化架构。 -旧版本入口(app/core/memory/src/main.py)已删除。 - -主要功能: -1. 协调整个知识提取流水线 -2. 支持试运行模式和正常运行模式 -3. 使用重构后的 storage_services 模块 -4. 提供统一的配置管理和日志记录 - -作者:Lance77 -日期:2025-11-22 -""" - -# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误 -import os -os.environ["LANGCHAIN_TRACING_V2"] = "false" -os.environ["LANGCHAIN_TRACING"] = "false" -import asyncio -import time -from datetime import datetime -from typing import Optional, Callable, Awaitable -from dotenv import load_dotenv - -# 导入重构后的模块 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData -from app.core.memory.models.variate_config import ExtractionPipelineConfig - -# 导入数据加载函数 -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( - get_chunked_dialogs_with_preprocessing, - get_chunked_dialogs_from_preprocessed, -) -# 导入配置模块(而不是直接导入变量) 
-from app.core.memory.utils.config import definitions as config_defs -from app.core.logging_config import get_memory_logger, log_time - -load_dotenv() - -logger = get_memory_logger(__name__) - - - - - -async def main( - # Required configuration parameters (no longer from global variables) - chunker_strategy: str, - group_id: str, - user_id: str, - apply_id: str, - llm_model_id: str, - embedding_model_id: str, - # Optional parameters - dialogue_text: Optional[str] = None, - is_pilot_run: bool = False, - progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None -): - """ - 记忆系统主流程 - 重构版本 (Updated to eliminate global variables) - - 该函数是重构后的主入口,使用新的模块化架构。 - Global variables have been eliminated in favor of explicit parameters. - - Args: - chunker_strategy: Chunking strategy to use (required) - group_id: Group ID for the operation (required) - user_id: User ID for the operation (required) - apply_id: Application ID for the operation (required) - llm_model_id: LLM model ID to use (required) - embedding_model_id: Embedding model ID to use (required) - dialogue_text: 输入的对话文本(可选,用于试运行模式) - is_pilot_run: 是否为试运行模式 - - True: 试运行模式,不保存到 Neo4j - - False: 正常运行模式,保存到 Neo4j - progress_callback: 可选的进度回调函数 - - 类型: Callable[[str, str, Optional[dict]], Awaitable[None]] - - 参数1 (stage): 当前处理阶段标识符 - - 参数2 (message): 人类可读的进度消息 - - 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果 - - 在管线关键点调用以报告进度和结果数据 - - 工作流程: - 1. 初始化客户端和配置 - 2. 加载或准备数据 - 3. 执行知识提取流水线 - 4. 
保存结果(正常模式)或输出结果(试运行模式) - """ - print("=" * 60) - print("MemSci 知识提取流水线 - 重构版本") - print("=" * 60) - print(f"运行模式: {'试运行(不保存到Neo4j)' if is_pilot_run else '正常运行(保存到Neo4j)'}") - print("Using chunker strategy:", chunker_strategy) - print("Using group ID:", group_id) - print("Using model ID:", llm_model_id) - print("Using embedding model ID:", embedding_model_id) - print("=" * 60) - - # 初始化日志 - log_file = "logs/time.log" - os.makedirs(os.path.dirname(log_file), exist_ok=True) - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - with open(log_file, "a", encoding="utf-8") as f: - f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n") - - pipeline_start = time.time() - - try: - # 步骤 1: 初始化客户端 - logger.info("Initializing clients...") - step_start = time.time() - - llm_client = get_llm_client(llm_model_id) - - # 获取 embedder 配置并转换为 RedBearModelConfig 对象 - from app.core.models.base import RedBearModelConfig - embedder_config_dict = get_embedder_config(embedding_model_id) - embedder_config = RedBearModelConfig(**embedder_config_dict) - embedder_client = OpenAIEmbedderClient(embedder_config) - - neo4j_connector = Neo4jConnector() - - log_time("Client Initialization", time.time() - step_start, log_file) - - # 步骤 2: 加载或准备数据 - logger.info("Loading data...") - logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}") - logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}") - logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}") - step_start = time.time() - - if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip(): - # 试运行模式:处理前端传入的对话文本 - logger.info("[MAIN] ✓ Using frontend 
dialogue text (pilot run mode)") - import re - - # 解析对话文本,支持 "用户:" 和 "AI:" 格式 - pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)" - matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL) - messages = [ - ConversationMessage(role=r, msg=c.strip()) - for r, c in matches if c.strip() - ] - - # 如果没有匹配到格式化的对话,将整个文本作为用户消息 - if not messages: - messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())] - - # 创建对话上下文和对话数据 - context = ConversationContext(msgs=messages) - dialog = DialogData( - context=context, - ref_id="pilot_dialog_1", - group_id=group_id, - user_id=user_id, - apply_id=apply_id, - metadata={"source": "pilot_run", "input_type": "frontend_text"} - ) - - # 进度回调:开始预处理文本 - if progress_callback: - await progress_callback("text_preprocessing", "开始预处理文本...") - - # 对前端传入的对话进行分块处理 - chunked_dialogs = await get_chunked_dialogs_from_preprocessed( - data=[dialog], - chunker_strategy=chunker_strategy, - llm_client=llm_client, - ) - logger.info(f"Processed frontend dialogue text: {len(messages)} messages") - - # 进度回调:输出每个分块的结果 - if progress_callback: - for dialog in chunked_dialogs: - for i, chunk in enumerate(dialog.chunks): - chunk_result = { - "chunk_index": i + 1, - "content": chunk.content[:200] + "..." 
if len(chunk.content) > 200 else chunk.content, - "full_length": len(chunk.content), - "dialog_id": dialog.id, - "chunker_strategy": chunker_strategy - } - await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) - - # 进度回调:预处理文本完成 - preprocessing_summary = { - "total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs), - "total_dialogs": len(chunked_dialogs), - "chunker_strategy": chunker_strategy - } - await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary) - else: - # 正常运行模式:从 testdata.json 文件加载 - logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)") - logger.info("Loading data from testdata.json...") - test_data_path = os.path.join( - os.path.dirname(__file__), "data", "testdata.json" - ) - - if not os.path.exists(test_data_path): - raise FileNotFoundError(f"Test data file not found: {test_data_path}") - - # 进度回调:开始预处理文本 - if progress_callback: - await progress_callback("text_preprocessing", "开始预处理文本...") - - chunked_dialogs = await get_chunked_dialogs_with_preprocessing( - chunker_strategy=chunker_strategy, - group_id=group_id, - user_id=user_id, - apply_id=apply_id, - indices=None, - input_data_path=test_data_path, - llm_client=llm_client, - skip_cleaning=True, - ) - logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json") - - # 进度回调:输出每个分块的结果 - if progress_callback: - for dialog in chunked_dialogs: - for i, chunk in enumerate(dialog.chunks): - chunk_result = { - "chunk_index": i + 1, - "content": chunk.content[:200] + "..." 
if len(chunk.content) > 200 else chunk.content, - "full_length": len(chunk.content), - "dialog_id": dialog.id, - "chunker_strategy": chunker_strategy - } - await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) - - # 进度回调:预处理文本完成 - preprocessing_summary = { - "total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs), - "total_dialogs": len(chunked_dialogs), - "chunker_strategy": chunker_strategy - } - await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary) - - log_time("Data Loading & Chunking", time.time() - step_start, log_file) - - # 步骤 3: 初始化流水线编排器 - logger.info("Initializing extraction orchestrator...") - step_start = time.time() - - # 从 runtime.json 加载配置(已经过数据库覆写) - from app.core.memory.utils.config.config_utils import get_pipeline_config - config = get_pipeline_config() - - logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}") - - orchestrator = ExtractionOrchestrator( - llm_client=llm_client, - embedder_client=embedder_client, - connector=neo4j_connector, - config=config, - progress_callback=progress_callback, # 传递进度回调 - embedding_id=embedding_model_id, # 传递嵌入模型ID - ) - - log_time("Orchestrator Initialization", time.time() - step_start, log_file) - - # 步骤 4: 执行知识提取流水线 - logger.info("Running extraction pipeline...") - step_start = time.time() - - - # 进度回调:正在知识抽取 - if progress_callback: - await progress_callback("knowledge_extraction", "正在知识抽取...") - - extraction_result = await orchestrator.run( - dialog_data_list=chunked_dialogs, - is_pilot_run=is_pilot_run, # 传递试运行模式标志 - ) - - # 解包 extraction_result tuple - # extraction_result 是一个包含 7 个元素的 tuple: - # (dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes, - # statement_chunk_edges, statement_entity_edges, entity_edges) - ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, 
- statement_chunk_edges, - statement_entity_edges, - entity_edges, - ) = extraction_result - - log_time("Extraction Pipeline", time.time() - step_start, log_file) - - # 进度回调:生成结果 - if progress_callback: - await progress_callback("generating_results", "正在生成结果...") - - - # 步骤 5: 保存结果或输出结果 - if is_pilot_run: - logger.info("Pilot run mode: Skipping Neo4j save") - print("\n试运行模式:跳过 Neo4j 保存,流水线处理完成。") - print("提取结果已生成,可在相关输出中查看。") - else: - logger.info("Normal mode: Saving to Neo4j...") - step_start = time.time() - - # 创建索引和约束 - try: - from app.repositories.neo4j.create_indexes import ( - create_fulltext_indexes, - create_unique_constraints, - ) - await create_fulltext_indexes() - await create_unique_constraints() - logger.info("Successfully created indexes and constraints") - except Exception as e: - logger.error(f"Error creating indexes/constraints: {e}") - - # 保存数据到 Neo4j - try: - from app.repositories.neo4j.graph_saver import ( - save_dialog_and_statements_to_neo4j, - ) - - success = await save_dialog_and_statements_to_neo4j( - dialogue_nodes=dialogue_nodes, - chunk_nodes=chunk_nodes, - statement_nodes=statement_nodes, - entity_nodes=entity_nodes, - statement_chunk_edges=statement_chunk_edges, - statement_entity_edges=statement_entity_edges, - entity_edges=entity_edges, - connector=neo4j_connector, - ) - - if success: - logger.info("Successfully saved all data to Neo4j") - print("\n✓ 成功保存所有数据到 Neo4j") - else: - logger.warning("Failed to save some data to Neo4j") - print("\n⚠ 部分数据保存到 Neo4j 失败") - except Exception as e: - logger.error(f"Error saving to Neo4j: {e}", exc_info=True) - print(f"\n✗ 保存到 Neo4j 失败: {e}") - - log_time("Neo4j Database Save", time.time() - step_start, log_file) - - # 步骤 6: 生成记忆摘要(可选) - try: - logger.info("Generating memory summaries...") - step_start = time.time() - - from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( - Memory_summary_generation, - ) - from app.repositories.neo4j.add_nodes import 
add_memory_summary_nodes - from app.repositories.neo4j.add_edges import ( - add_memory_summary_statement_edges, - ) - - summaries = await Memory_summary_generation( - chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id - ) - - if not is_pilot_run: - # 保存记忆摘要到 Neo4j - ms_connector = Neo4jConnector() - try: - await add_memory_summary_nodes(summaries, ms_connector) - await add_memory_summary_statement_edges(summaries, ms_connector) - finally: - await ms_connector.close() - - log_time("Memory Summary Generation", time.time() - step_start, log_file) - except Exception as e: - logger.error(f"Memory summary step failed: {e}", exc_info=True) - - except Exception as e: - logger.error(f"Pipeline execution failed: {e}", exc_info=True) - print(f"\n✗ 流水线执行失败: {e}") - raise - finally: - # 清理资源 - try: - await neo4j_connector.close() - except Exception: - pass - - # 记录总时间 - total_time = time.time() - pipeline_start - log_time("TOTAL PIPELINE TIME", total_time, log_file) - - # 添加完成标记 - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - with open(log_file, "a", encoding="utf-8") as f: - f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") - - logger.info("=== Pipeline Complete ===") - logger.info(f"Total execution time: {total_time:.2f} seconds") - logger.info(f"Timing details saved to: {log_file}") - - print("\n" + "=" * 60) - print("✓ 流水线执行完成") - print(f"✓ 总耗时: {total_time:.2f} 秒") - print(f"✓ 详细日志: {log_file}") - print("=" * 60) - - -if __name__ == "__main__": - print("⚠️ Warning: This script now requires explicit configuration parameters.") - print("Global variables have been removed. 
Please provide configuration parameters.") - print("Example usage:") - print(" asyncio.run(main(") - print(" chunker_strategy='RecursiveChunker',") - print(" group_id='your_group_id',") - print(" user_id='your_user_id',") - print(" apply_id='your_apply_id',") - print(" llm_model_id='your_llm_id',") - print(" embedding_model_id='your_embedding_id'") - print(" ))") - - # This will fail because global variables are removed - raise RuntimeError("Global variables removed. Please provide explicit configuration parameters.") diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index a8c3f7b0..5977a2d7 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -20,14 +20,14 @@ Classes: MemorySummaryNode: Node representing a memory summary """ -from uuid import uuid4 +import re from datetime import datetime, timezone from typing import List, Optional -from pydantic import BaseModel, Field, field_validator -import re +from uuid import uuid4 -from app.core.memory.utils.data.ontology import TemporalInfo from app.core.memory.utils.alias_utils import validate_aliases +from app.core.memory.utils.data.ontology import TemporalInfo +from pydantic import BaseModel, Field, field_validator def parse_historical_datetime(v): @@ -361,7 +361,7 @@ class ExtractedEntityNode(Node): description="Entity aliases - alternative names for this entity" ) name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") - fact_summary: str = Field(..., description="Summary of the fact about this entity") + fact_summary: str = Field(default="", description="Summary of the fact about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") diff --git a/api/app/core/memory/models/triplet_models.py 
b/api/app/core/memory/models/triplet_models.py index 2325f3bd..b0a062a3 100644 --- a/api/app/core/memory/models/triplet_models.py +++ b/api/app/core/memory/models/triplet_models.py @@ -10,9 +10,10 @@ Classes: """ from typing import List, Optional -from pydantic import BaseModel, Field, ConfigDict from uuid import uuid4 +from pydantic import BaseModel, ConfigDict, Field + class Entity(BaseModel): """Represents an extracted entity from dialogue. diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index 597a4789..9353f00e 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -5,7 +5,10 @@ import math import os import time from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + from app.schemas.memory_config_schema import MemoryConfig from app.core.logging_config import get_memory_logger from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient @@ -14,16 +17,14 @@ from app.core.memory.models.variate_config import ForgettingEngineConfig from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ( ForgettingEngine, ) -from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.config.config_utils import ( - get_embedder_config, get_pipeline_config, ) -from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG from app.core.memory.utils.data.text_utils import extract_plain_query from app.core.memory.utils.data.time_utils import normalize_date_safe from app.core.memory.utils.llm.llm_utils import get_reranker_client from app.core.models.base import RedBearModelConfig +from app.db import get_db_context from app.repositories.neo4j.graph_search import ( search_graph, search_graph_by_chunk_id, @@ -34,6 +35,7 @@ from app.repositories.neo4j.graph_search import ( # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import 
Neo4jConnector +from app.services.memory_config_service import MemoryConfigService from dotenv import load_dotenv load_dotenv() @@ -324,229 +326,229 @@ def apply_reranker_placeholder( If config enables reranker, annotate items with a final_score equal to combined_score and keep ordering. This is a no-op reranker to be replaced later. """ - try: - rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})) - except Exception as e: - logger.debug(f"Failed to load reranker config: {e}") - rc = {} - if not rc or not rc.get("enabled", False): - return results + # try: + # rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})) + # except Exception as e: + # logger.debug(f"Failed to load reranker config: {e}") + # rc = {} + # if not rc or not rc.get("enabled", False): + # return results - top_k = int(rc.get("top_k", 100)) - model_name = rc.get("model", "placeholder") + # top_k = int(rc.get("top_k", 100)) + # model_name = rc.get("model", "placeholder") - for cat, items in results.items(): - head = items[:top_k] - for it in head: - base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0) - it["final_score"] = base - it["reranker_model"] = model_name - # Keep overall order by final_score if present, otherwise combined/score - results[cat] = sorted( - items, - key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)), - reverse=True, - ) + # for cat, items in results.items(): + # head = items[:top_k] + # for it in head: + # base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0) + # it["final_score"] = base + # it["reranker_model"] = model_name + # # Keep overall order by final_score if present, otherwise combined/score + # results[cat] = sorted( + # items, + # key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)), + # reverse=True, + # ) return results -async def apply_llm_reranker( - results: Dict[str, List[Dict[str, Any]]], - query_text: str, - 
reranker_client: Optional[Any] = None, - llm_weight: Optional[float] = None, - top_k: Optional[int] = None, - batch_size: Optional[int] = None, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Apply LLM-based reranking to search results. +# async def apply_llm_reranker( +# results: Dict[str, List[Dict[str, Any]]], +# query_text: str, +# reranker_client: Optional[Any] = None, +# llm_weight: Optional[float] = None, +# top_k: Optional[int] = None, +# batch_size: Optional[int] = None, +# ) -> Dict[str, List[Dict[str, Any]]]: +# """ +# Apply LLM-based reranking to search results. - Args: - results: Search results organized by category - query_text: Original search query - reranker_client: Optional pre-initialized reranker client - llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM) - top_k: Maximum number of items to rerank per category - batch_size: Number of items to process concurrently +# Args: +# results: Search results organized by category +# query_text: Original search query +# reranker_client: Optional pre-initialized reranker client +# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM) +# top_k: Maximum number of items to rerank per category +# batch_size: Number of items to process concurrently - Returns: - Reranked results with final_score and reranker_model fields - """ - # Load reranker configuration from runtime.json - try: - rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}) - except Exception as e: - logger.debug(f"Failed to load reranker config: {e}") - rc = {} +# Returns: +# Reranked results with final_score and reranker_model fields +# """ +# # Load reranker configuration from runtime.json +# # try: +# # rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}) +# # except Exception as e: +# # logger.debug(f"Failed to load reranker config: {e}") +# # rc = {} - # Check if reranking is enabled - enabled = rc.get("enabled", False) - if not enabled: - logger.debug("LLM reranking is disabled in 
configuration") - return results +# # Check if reranking is enabled +# enabled = rc.get("enabled", False) +# if not enabled: +# logger.debug("LLM reranking is disabled in configuration") +# return results - # Load configuration parameters with defaults - llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5) - top_k = top_k if top_k is not None else rc.get("top_k", 20) - batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5) +# # Load configuration parameters with defaults +# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5) +# top_k = top_k if top_k is not None else rc.get("top_k", 20) +# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5) - # Initialize reranker client if not provided - if reranker_client is None: - try: - reranker_client = get_reranker_client() - except Exception as e: - logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking") - return results +# # Initialize reranker client if not provided +# if reranker_client is None: +# try: +# reranker_client = get_reranker_client() +# except Exception as e: +# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking") +# return results - # Get model name for metadata - model_name = getattr(reranker_client, 'model_name', 'unknown') +# # Get model name for metadata +# model_name = getattr(reranker_client, 'model_name', 'unknown') - # Process each category - reranked_results = {} - for category in ["statements", "chunks", "entities", "summaries"]: - items = results.get(category, []) - if not items: - reranked_results[category] = [] - continue +# # Process each category +# reranked_results = {} +# for category in ["statements", "chunks", "entities", "summaries"]: +# items = results.get(category, []) +# if not items: +# reranked_results[category] = [] +# continue - # Select top K items by combined_score for reranking - sorted_items = sorted( - items, - 
key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0), - reverse=True - ) +# # Select top K items by combined_score for reranking +# sorted_items = sorted( +# items, +# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0), +# reverse=True +# ) - top_items = sorted_items[:top_k] - remaining_items = sorted_items[top_k:] +# top_items = sorted_items[:top_k] +# remaining_items = sorted_items[top_k:] - # Extract text content from each item - def extract_text(item: Dict[str, Any]) -> str: - """Extract text content from a result item.""" - # Try different text fields based on category - text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or "" - return str(text).strip() +# # Extract text content from each item +# def extract_text(item: Dict[str, Any]) -> str: +# """Extract text content from a result item.""" +# # Try different text fields based on category +# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or "" +# return str(text).strip() - # Batch items for concurrent processing - batches = [] - for i in range(0, len(top_items), batch_size): - batch = top_items[i:i + batch_size] - batches.append(batch) +# # Batch items for concurrent processing +# batches = [] +# for i in range(0, len(top_items), batch_size): +# batch = top_items[i:i + batch_size] +# batches.append(batch) - # Process batches concurrently - async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Process a batch of items with LLM relevance scoring.""" - scored_batch = [] +# # Process batches concurrently +# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +# """Process a batch of items with LLM relevance scoring.""" +# scored_batch = [] - for item in batch: - item_text = extract_text(item) +# for item in batch: +# item_text = extract_text(item) - # Skip items with no text - if not item_text: - item_copy = item.copy() - 
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - item_copy["final_score"] = combined_score - item_copy["llm_relevance_score"] = 0.0 - item_copy["reranker_model"] = model_name - scored_batch.append(item_copy) - continue +# # Skip items with no text +# if not item_text: +# item_copy = item.copy() +# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) +# item_copy["final_score"] = combined_score +# item_copy["llm_relevance_score"] = 0.0 +# item_copy["reranker_model"] = model_name +# scored_batch.append(item_copy) +# continue - # Create relevance scoring prompt - prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0. +# # Create relevance scoring prompt +# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0. -Query: {query_text} +# Query: {query_text} -Result: {item_text} +# Result: {item_text} -Respond with only a number between 0.0 and 1.0, where: -- 0.0 means completely irrelevant -- 1.0 means perfectly relevant +# Respond with only a number between 0.0 and 1.0, where: +# - 0.0 means completely irrelevant +# - 1.0 means perfectly relevant -Relevance score:""" +# Relevance score:""" - # Send request to LLM - try: - messages = [{"role": "user", "content": prompt}] - response = await reranker_client.chat(messages) +# # Send request to LLM +# try: +# messages = [{"role": "user", "content": prompt}] +# response = await reranker_client.chat(messages) - # Parse LLM response to extract relevance score - response_text = str(response.content if hasattr(response, 'content') else response).strip() +# # Parse LLM response to extract relevance score +# response_text = str(response.content if hasattr(response, 'content') else response).strip() - # Try to extract a float from the response - try: - # Remove any non-numeric characters except decimal point - import re - 
score_match = re.search(r'(\d+\.?\d*)', response_text) - if score_match: - llm_score = float(score_match.group(1)) - # Clamp to [0.0, 1.0] - llm_score = max(0.0, min(1.0, llm_score)) - else: - raise ValueError("No numeric score found in response") - except (ValueError, AttributeError) as e: - logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}") - llm_score = None +# # Try to extract a float from the response +# try: +# # Remove any non-numeric characters except decimal point +# import re +# score_match = re.search(r'(\d+\.?\d*)', response_text) +# if score_match: +# llm_score = float(score_match.group(1)) +# # Clamp to [0.0, 1.0] +# llm_score = max(0.0, min(1.0, llm_score)) +# else: +# raise ValueError("No numeric score found in response") +# except (ValueError, AttributeError) as e: +# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}") +# llm_score = None - # Calculate final score - item_copy = item.copy() - combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) +# # Calculate final score +# item_copy = item.copy() +# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - if llm_score is not None: - final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score - item_copy["llm_relevance_score"] = llm_score - else: - # Use combined_score as fallback - final_score = combined_score - item_copy["llm_relevance_score"] = combined_score +# if llm_score is not None: +# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score +# item_copy["llm_relevance_score"] = llm_score +# else: +# # Use combined_score as fallback +# final_score = combined_score +# item_copy["llm_relevance_score"] = combined_score - item_copy["final_score"] = final_score - item_copy["reranker_model"] = model_name - scored_batch.append(item_copy) - except Exception as e: - logger.warning(f"Error processing item in LLM reranking: {e}, 
using combined_score") - item_copy = item.copy() - combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - item_copy["final_score"] = combined_score - item_copy["llm_relevance_score"] = combined_score - item_copy["reranker_model"] = model_name - scored_batch.append(item_copy) +# item_copy["final_score"] = final_score +# item_copy["reranker_model"] = model_name +# scored_batch.append(item_copy) +# except Exception as e: +# logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score") +# item_copy = item.copy() +# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) +# item_copy["final_score"] = combined_score +# item_copy["llm_relevance_score"] = combined_score +# item_copy["reranker_model"] = model_name +# scored_batch.append(item_copy) - return scored_batch +# return scored_batch - # Process all batches concurrently - try: - batch_tasks = [process_batch(batch) for batch in batches] - batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True) +# # Process all batches concurrently +# try: +# batch_tasks = [process_batch(batch) for batch in batches] +# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True) - # Merge batch results - scored_items = [] - for result in batch_results: - if isinstance(result, Exception): - logger.warning(f"Batch processing failed: {result}") - continue - scored_items.extend(result) +# # Merge batch results +# scored_items = [] +# for result in batch_results: +# if isinstance(result, Exception): +# logger.warning(f"Batch processing failed: {result}") +# continue +# scored_items.extend(result) - # Add remaining items (not in top K) with their combined_score as final_score - for item in remaining_items: - item_copy = item.copy() - combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - item_copy["final_score"] = combined_score - item_copy["reranker_model"] = model_name - scored_items.append(item_copy) 
+# # Add remaining items (not in top K) with their combined_score as final_score +# for item in remaining_items: +# item_copy = item.copy() +# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) +# item_copy["final_score"] = combined_score +# item_copy["reranker_model"] = model_name +# scored_items.append(item_copy) - # Sort all items by final_score in descending order - scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True) - reranked_results[category] = scored_items +# # Sort all items by final_score in descending order +# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True) +# reranked_results[category] = scored_items - except Exception as e: - logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results") - # Return original items with combined_score as final_score - for item in items: - combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - item["final_score"] = combined_score - item["reranker_model"] = model_name - reranked_results[category] = items +# except Exception as e: +# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results") +# # Return original items with combined_score as final_score +# for item in items: +# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) +# item["final_score"] = combined_score +# item["reranker_model"] = model_name +# reranked_results[category] = items - return reranked_results +# return reranked_results async def run_hybrid_search( @@ -556,7 +558,7 @@ async def run_hybrid_search( limit: int, include: List[str], output_path: str | None, - embedding_id: str, + memory_config: "MemoryConfig", rerank_alpha: float = 0.6, use_forgetting_rerank: bool = False, use_llm_rerank: bool = False, @@ -564,11 +566,14 @@ async def run_hybrid_search( """ Run search with specified type: 'keyword', 'embedding', or 'hybrid' 
+ + Args: + memory_config: MemoryConfig object containing embedding_model_id and config_id """ # Start overall timing search_start_time = time.time() latency_metrics = {} - logger.info(f"using embedding_id:{embedding_id}...") + logger.info(f"using embedding_id:{memory_config.embedding_model_id}...") # Clean and normalize the incoming query before use/logging query_text = extract_plain_query(query_text) @@ -621,7 +626,9 @@ async def run_hybrid_search( # 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig config_load_start = time.time() - embedder_config_dict = get_embedder_config(embedding_id) + with get_db_context() as db: + config_service = MemoryConfigService(db) + embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id)) rb_config = RedBearModelConfig( model_name=embedder_config_dict["model_name"], provider=embedder_config_dict["provider"], @@ -683,7 +690,7 @@ async def run_hybrid_search( if use_forgetting_rerank: # Load forgetting parameters from pipeline config try: - pc = get_pipeline_config() + pc = get_pipeline_config(memory_config) forgetting_cfg = pc.forgetting_engine except Exception as e: logger.debug(f"Failed to load forgetting config, using defaults: {e}") @@ -711,16 +718,16 @@ async def run_hybrid_search( # Apply LLM reranking if enabled llm_rerank_applied = False - if use_llm_rerank: - try: - reranked_results = await apply_llm_reranker( - results=reranked_results, - query_text=query_text, - ) - llm_rerank_applied = True - logger.info("LLM reranking applied successfully") - except Exception as e: - logger.warning(f"LLM reranking failed: {e}, using previous scores") + # if use_llm_rerank: + # try: + # reranked_results = await apply_llm_reranker( + # results=reranked_results, + # query_text=query_text, + # ) + # llm_rerank_applied = True + # logger.info("LLM reranking applied successfully") + # except Exception as e: + # logger.warning(f"LLM reranking failed: {e}, using previous scores") results["reranked_results"] = 
reranked_results results["combined_summary"] = { @@ -896,90 +903,95 @@ async def search_chunk_by_chunk_id( return {"chunks": chunks} -def main(): - """Main entry point for the hybrid graph search CLI. +# def main(): +# """Main entry point for the hybrid graph search CLI. - Parses command line arguments and executes search with specified parameters. - Supports keyword, embedding, and hybrid search modes. - """ - parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options") - parser.add_argument( - "--query", "-q", required=True, help="Free-text query to search" - ) - parser.add_argument( - "--search-type", - "-t", - choices=["keyword", "embedding", "hybrid"], - default="hybrid", - help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)" - ) - parser.add_argument( - "--embedding-name", - "-m", - default="openai/nomic-embed-text:v1.5", - help="Embedding config name from config.json (default: openai/nomic-embed-text:v1.5)", - ) - parser.add_argument( - "--group-id", - "-g", - default=None, - help="Optional group_id to filter results (default: None)", - ) - parser.add_argument( - "--limit", - "-k", - type=int, - default=5, - help="Max number of results per type (default: 5)", - ) - parser.add_argument( - "--include", - "-i", - nargs="+", - default=["statements", "chunks", "entities", "summaries"], - choices=["statements", "chunks", "entities", "summaries"], - help="Which targets to search for embedding search (default: statements chunks entities summaries)" - ) - parser.add_argument( - "--output", - "-o", - default="search_results.json", - help="Path to save the search results JSON (default: search_results.json)", - ) - parser.add_argument( - "--rerank-alpha", - "-a", - type=float, - default=0.6, - help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)", - ) - parser.add_argument( - "--forgetting-rerank", - action="store_true", - 
help="Apply forgetting curve during reranking for hybrid search.", - ) - parser.add_argument( - "--llm-rerank", - action="store_true", - help="Apply LLM-based reranking for hybrid search.", - ) - args = parser.parse_args() +# Parses command line arguments and executes search with specified parameters. +# Supports keyword, embedding, and hybrid search modes. +# """ +# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options") +# parser.add_argument( +# "--query", "-q", required=True, help="Free-text query to search" +# ) +# parser.add_argument( +# "--search-type", +# "-t", +# choices=["keyword", "embedding", "hybrid"], +# default="hybrid", +# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)" +# ) +# parser.add_argument( +# "--config-id", +# "-c", +# type=int, +# required=True, +# help="Database configuration ID (required)", +# ) +# parser.add_argument( +# "--group-id", +# "-g", +# default=None, +# help="Optional group_id to filter results (default: None)", +# ) +# parser.add_argument( +# "--limit", +# "-k", +# type=int, +# default=5, +# help="Max number of results per type (default: 5)", +# ) +# parser.add_argument( +# "--include", +# "-i", +# nargs="+", +# default=["statements", "chunks", "entities", "summaries"], +# choices=["statements", "chunks", "entities", "summaries"], +# help="Which targets to search for embedding search (default: statements chunks entities summaries)" +# ) +# parser.add_argument( +# "--output", +# "-o", +# default="search_results.json", +# help="Path to save the search results JSON (default: search_results.json)", +# ) +# parser.add_argument( +# "--rerank-alpha", +# "-a", +# type=float, +# default=0.6, +# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)", +# ) +# parser.add_argument( +# "--forgetting-rerank", +# action="store_true", +# help="Apply forgetting curve during reranking for hybrid 
search.", +# ) +# parser.add_argument( +# "--llm-rerank", +# action="store_true", +# help="Apply LLM-based reranking for hybrid search.", +# ) +# args = parser.parse_args() - asyncio.run( - run_hybrid_search( - query_text=args.query, - search_type=args.search_type, - group_id=args.group_id, - limit=args.limit, - include=args.include, - output_path=args.output, - embedding_id=config_defs.SELECTED_EMBEDDING_ID, - rerank_alpha=args.rerank_alpha, - use_forgetting_rerank=args.forgetting_rerank, - use_llm_rerank=args.llm_rerank, - ) - ) +# # Load memory config from database +# from app.services.memory_config_service import MemoryConfigService +# memory_config = MemoryConfigService.load_memory_config(args.config_id) + +# asyncio.run( +# run_hybrid_search( +# query_text=args.query, +# search_type=args.search_type, +# group_id=args.group_id, +# limit=args.limit, +# include=args.include, +# output_path=args.output, +# memory_config=memory_config, +# rerank_alpha=args.rerank_alpha, +# use_forgetting_rerank=args.forgetting_rerank, +# use_llm_rerank=args.llm_rerank, +# ) +# ) -if __name__ == "__main__": - main() +# if __name__ == "__main__": +# main() diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index 9088a300..62b656b0 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -1,19 +1,22 @@ """ 去重功能函数 """ -from app.core.memory.models.variate_config import DedupConfig -from typing import List, Dict, Tuple, Any -from app.core.memory.models.graph_models import( - StatementEntityEdge, - EntityEntityEdge, - ExtractedEntityNode -) -import os -from datetime import datetime -import difflib # 提供字符串相似度计算工具 import asyncio +import difflib # 提供字符串相似度计算工具 import importlib +import os import re +from 
datetime import datetime +from typing import Any, Dict, List, Tuple + +from app.core.memory.models.graph_models import ( + EntityEntityEdge, + ExtractedEntityNode, + StatementEntityEdge, +) +from app.core.memory.models.variate_config import DedupConfig + + # 模块级类型统一工具函数 def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None: """统一实体类型:基于LLM建议或启发式规则选择最合适的类型。 @@ -705,7 +708,8 @@ async def LLM_decision( # 决策中包含去重和消歧的功能 statement_entity_edges: List[StatementEntityEdge], entity_entity_edges: List[EntityEntityEdge], id_redirect: Dict[str, str], - config: DedupConfig | None = None, + config: DedupConfig, + llm_client = None, ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]: """ 基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。 @@ -717,26 +721,13 @@ async def LLM_decision( # 决策中包含去重和消歧的功能 """ llm_records: List[str] = [] try: - # 优先使用运行时配置;若未提供配置,使用模型默认值,不再回退到环境变量 - enable_switch = ( - bool(config.enable_llm_dedup_blockwise) if config is not None else DedupConfig().enable_llm_dedup_blockwise - ) - if not enable_switch: - return deduped_entities, id_redirect, llm_records - # 从配置读取 LLM 迭代参数;若无配置则使用 DedupConfig 的默认值 - _defaults = DedupConfig() - block_size = (config.llm_block_size if config is not None else _defaults.llm_block_size) - block_concurrency = (config.llm_block_concurrency if config is not None else _defaults.llm_block_concurrency) - pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency) - max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds) - - # 动态导入 llm 客户端(修正导入路径) - try: - llm_utils_mod = importlib.import_module("app.core.memory.utils.llm.llm_utils") - get_llm_client_fn = llm_utils_mod.get_llm_client - except Exception as e: - llm_records.append(f"[LLM错误] 无法导入 llm_utils 模块: {e}") + if not bool(config.enable_llm_dedup_blockwise): return deduped_entities, id_redirect, llm_records + # 从配置读取 LLM 迭代参数 + block_size = 
config.llm_block_size + block_concurrency = config.llm_block_concurrency + pair_concurrency = config.llm_pair_concurrency + max_rounds = config.llm_max_rounds try: llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm") @@ -745,14 +736,9 @@ async def LLM_decision( # 决策中包含去重和消歧的功能 llm_records.append(f"[LLM错误] 无法导入 entity_dedup_llm 模块: {e}") return deduped_entities, id_redirect, llm_records - # 获取 LLM 客户端 - try: - llm_client = get_llm_client_fn() - if llm_client is None: - llm_records.append("[LLM错误] LLM 客户端初始化失败:返回 None") - return deduped_entities, id_redirect, llm_records - except Exception as e: - llm_records.append(f"[LLM错误] 获取 LLM 客户端失败: {e}") + # 验证 LLM 客户端 + if llm_client is None: + llm_records.append("[LLM错误] LLM 客户端未提供") return deduped_entities, id_redirect, llm_records llm_redirect, llm_records = await llm_fn( @@ -813,7 +799,8 @@ async def LLM_disamb_decision( statement_entity_edges: List[StatementEntityEdge], entity_entity_edges: List[EntityEntityEdge], id_redirect: Dict[str, str], - config: DedupConfig | None = None, + config: DedupConfig, + llm_client = None, ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]: """ 预消歧阶段:对“同名但类型不同”的实体对调用LLM进行消歧, @@ -824,22 +811,16 @@ async def LLM_disamb_decision( disamb_records: List[str] = [] blocked_pairs: set[tuple[str, str]] = set() try: - enable_switch = ( - config.enable_llm_disambiguation - if config is not None - else DedupConfig().enable_llm_disambiguation - ) - if not bool(enable_switch): + if not bool(config.enable_llm_disambiguation): return deduped_entities, id_redirect, blocked_pairs, disamb_records - from app.core.memory.utils.llm.llm_utils import get_llm_client - from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative - from app.core.memory.utils.config import definitions as config_defs + from 
app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import ( + llm_disambiguate_pairs_iterative, + ) - # 获取 LLM 客户端并验证 - llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) + # 验证 LLM 客户端 if llm_client is None: - disamb_records.append("[DISAMB错误] LLM 客户端初始化失败:返回 None") + disamb_records.append("[DISAMB错误] LLM 客户端未提供") return deduped_entities, id_redirect, blocked_pairs, disamb_records merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative( @@ -895,6 +876,7 @@ async def deduplicate_entities_and_edges( report_append: bool = False, report_stage_notes: List[str] | None = None, dedup_config: DedupConfig | None = None, + llm_client = None, ) -> Tuple[ List[ExtractedEntityNode], List[StatementEntityEdge], @@ -911,7 +893,7 @@ async def deduplicate_entities_and_edges( # 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并 deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision( - deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config + deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client ) # 2) 模糊匹配(本地规则) @@ -936,7 +918,7 @@ async def deduplicate_entities_and_edges( if should_trigger_llm: deduped_entities, id_redirect, llm_decision_records = await LLM_decision( - deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config + deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client ) else: llm_decision_records = [] diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py index 04aa6cb6..b41f35a4 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py +++ 
b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py @@ -10,15 +10,27 @@ from __future__ import annotations -from typing import List, Dict, Any, Tuple from datetime import datetime +from typing import Any, Dict, List, Tuple + +from app.core.memory.models.graph_models import ( + EntityEntityEdge, + ExtractedEntityNode, + StatementEntityEdge, +) +from app.core.memory.models.variate_config import DedupConfig +from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( # 导入报告写入以在跳过时追加说明 + _write_dedup_fusion_report, + deduplicate_entities_and_edges, +) +from app.repositories.neo4j.graph_search import ( + get_dedup_candidates_for_entities, # 导入ge函数,用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。 +) # 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互 -from app.repositories.neo4j.graph_search import get_dedup_candidates_for_entities # 导入ge函数,用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。 -from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge -from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges, _write_dedup_fusion_report # 导入报告写入以在跳过时追加说明 -from app.core.memory.models.variate_config import DedupConfig +from app.repositories.neo4j.neo4j_connector import ( + Neo4jConnector, # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互 +) def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt,用于将任意类型的输入值解析为datetime对象(处理实体节点中的时间字段) @@ -72,6 +84,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑 statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系 entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系 dedup_config: DedupConfig | None = None, + llm_client = None, ) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]: """ 第二层去重消歧: @@ -137,13 +150,14 @@ async def 
second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑 union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes) # 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重) - fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges = await deduplicate_entities_and_edges( + fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges, _ = await deduplicate_entities_and_edges( union_entities, statement_entity_edges, entity_entity_edges, report_stage="第二层去重消歧", report_append=True, dedup_config=dedup_config, + llm_client=llm_client, ) return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py index e4857ff3..11845d7d 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py @@ -1,23 +1,27 @@ from __future__ import annotations -from typing import List, Tuple, Optional +from typing import List, Optional, Tuple -from app.core.memory.models.variate_config import ExtractionPipelineConfig -from app.core.memory.utils.config.config_utils import get_pipeline_config -from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges -from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import second_layer_dedup_and_merge_with_neo4j -# 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.models.graph_models import ( - DialogueNode, ChunkNode, - StatementNode, + DialogueNode, + EntityEntityEdge, ExtractedEntityNode, StatementChunkEdge, StatementEntityEdge, - EntityEntityEdge, + StatementNode, ) from app.core.memory.models.message_models import DialogData +from 
app.core.memory.models.variate_config import ExtractionPipelineConfig +from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( + deduplicate_entities_and_edges, +) +from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import ( + second_layer_dedup_and_merge_with_neo4j, +) + +# 使用新的仓储层 +from app.repositories.neo4j.neo4j_connector import Neo4jConnector async def dedup_layers_and_merge_and_return( @@ -29,8 +33,9 @@ async def dedup_layers_and_merge_and_return( statement_entity_edges: List[StatementEntityEdge], entity_entity_edges: List[EntityEntityEdge], dialog_data_list: List[DialogData], - pipeline_config: Optional[ExtractionPipelineConfig] = None, + pipeline_config: ExtractionPipelineConfig, connector: Optional[Neo4jConnector] = None, + llm_client = None, ) -> Tuple[ List[DialogueNode], List[ChunkNode], @@ -48,12 +53,9 @@ async def dedup_layers_and_merge_and_return( 返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。 """ - # 默认从 runtime.json 加载管线配置,避免回退到环境变量 + # pipeline_config is required - caller must provide it if pipeline_config is None: - try: - pipeline_config = get_pipeline_config() - except Exception: - pipeline_config = None + raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return") # 先探测 group_id,决定报告写入策略 group_id: Optional[str] = None @@ -70,6 +72,7 @@ async def dedup_layers_and_merge_and_return( report_stage="第一层去重消歧", report_append=False, dedup_config=(pipeline_config.deduplication if pipeline_config else None), + llm_client=llm_client, ) # 初始化第二层融合结果为第一层结果 @@ -88,6 +91,7 @@ async def dedup_layers_and_merge_and_return( statement_entity_edges=dedup_statement_entity_edges, entity_entity_edges=dedup_entity_entity_edges, dedup_config=(pipeline_config.deduplication if pipeline_config else None), + llm_client=llm_client, ) else: print("Skip second-layer dedup: missing connector") diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py 
b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index fa079c97..7c2ed5f4 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1253,6 +1253,7 @@ class ExtractionOrchestrator: report_stage="第一层去重消歧(试运行)", report_append=False, dedup_config=self.config.deduplication, + llm_client=self.llm_client, ) # 保存去重消歧的详细记录到实例变量 @@ -1284,6 +1285,7 @@ class ExtractionOrchestrator: dialog_data_list, self.config, self.connector, + llm_client=self.llm_client, ) # 解包返回值 diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py index 396c1e9e..72f3641e 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py @@ -5,11 +5,13 @@ """ import asyncio -from typing import List, Dict, Any, Tuple -from app.core.memory.models.message_models import DialogData +from typing import Any, Dict, List, Tuple + from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.config_utils import get_embedder_config +from app.core.memory.models.message_models import DialogData from app.core.models.base import RedBearModelConfig +from app.db import get_db_context +from app.services.memory_config_service import MemoryConfigService class EmbeddingGenerator: @@ -21,7 +23,9 @@ class EmbeddingGenerator: Args: embedding_id: 嵌入模型 ID """ - embedder_config = get_embedder_config(embedding_id) + with get_db_context() as db: + config_service = MemoryConfigService(db) + embedder_config = config_service.get_embedder_config(embedding_id) self.embedder_client = OpenAIEmbedderClient( 
model_config=RedBearModelConfig.model_validate(embedder_config), ) diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py index ffd4ed12..70c1ceb3 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py @@ -1,21 +1,17 @@ -import os import asyncio from datetime import datetime from typing import List, Optional - -from pydantic import Field, field_validator +from uuid import uuid4 from app.core.logging_config import get_memory_logger +from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +from app.core.memory.models.base_response import RobustLLMResponse +from app.core.memory.models.graph_models import MemorySummaryNode from app.core.memory.models.message_models import DialogData +from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt +from pydantic import Field logger = get_memory_logger(__name__) -from app.core.memory.models.graph_models import MemorySummaryNode -from app.core.memory.models.base_response import RobustLLMResponse -from app.core.models.base import RedBearModelConfig -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt -from uuid import uuid4 class MemorySummaryResponse(RobustLLMResponse): @@ -91,22 +87,17 @@ async def _process_chunk_summary( return None -async def Memory_summary_generation( +async def memory_summary_generation( chunked_dialogs: List[DialogData], llm_client, - embedding_id, + embedder_client: OpenAIEmbedderClient, ) -> List[MemorySummaryNode]: """Generate memory summaries per chunk, embed them, and return nodes.""" 
- embedder_cfg_dict = get_embedder_config(embedding_id) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(embedder_cfg_dict), - ) - # Collect all tasks for parallel processing tasks = [] for dialog in chunked_dialogs: for chunk in dialog.chunks: - tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder)) + tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder_client)) # Process all chunks in parallel results = await asyncio.gather(*tasks, return_exceptions=False) diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py index 1e79c339..17f76b17 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py @@ -1,17 +1,21 @@ -import os import asyncio import logging -from typing import List, Optional, Dict, Any -from pydantic import BaseModel, Field +import os from datetime import datetime +from typing import Any, Dict, List, Optional from app.core.memory.models.message_models import DialogData, Statement -#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。 -from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo +#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。 from app.core.memory.models.variate_config import StatementExtractionConfig +from app.core.memory.utils.data.ontology import ( + LABEL_DEFINITIONS, + RelevenceInfo, + StatementType, + TemporalInfo, +) from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt -from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo, RelevenceInfo +from pydantic import 
BaseModel, Field logger = logging.getLogger(__name__) diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index 6ccec500..6f537916 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -8,21 +8,25 @@ 4. 反思结果应用 - 更新记忆库 """ +import asyncio import json import logging -import asyncio import os import time -from typing import List, Dict, Any, Optional -from enum import Enum import uuid - -from pydantic import BaseModel +from enum import Enum +from typing import Any, Dict, List, Optional from app.core.response_utils import success -from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all -from app.repositories.neo4j.neo4j_update import neo4j_data +from app.repositories.neo4j.cypher_queries import ( + neo4j_query_all, + neo4j_query_part, + neo4j_statement_all, + neo4j_statement_part, +) from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.repositories.neo4j.neo4j_update import neo4j_data +from pydantic import BaseModel # 配置日志 _root_logger = logging.getLogger() @@ -135,14 +139,20 @@ class ReflectionEngine: self.neo4j_connector = Neo4jConnector() if self.llm_client is None: - from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.utils.config import definitions as config_defs - self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + from app.db import get_db_context + with get_db_context() as db: + factory = MemoryClientFactory(db) + self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) elif isinstance(self.llm_client, str): # 如果 llm_client 是字符串(model_id),则用它初始化客户端 - from app.core.memory.utils.llm.llm_utils import get_llm_client + from 
app.core.memory.utils.llm.llm_utils import MemoryClientFactory + from app.db import get_db_context model_id = self.llm_client - self.llm_client = get_llm_client(model_id) + with get_db_context() as db: + factory = MemoryClientFactory(db) + self.llm_client = factory.get_llm_client(model_id) if self.get_data_func is None: from app.core.memory.utils.config.get_data import get_data @@ -154,11 +164,15 @@ class ReflectionEngine: self.get_data_statement = get_data_statement if self.render_evaluate_prompt_func is None: - from app.core.memory.utils.prompt.template_render import render_evaluate_prompt + from app.core.memory.utils.prompt.template_render import ( + render_evaluate_prompt, + ) self.render_evaluate_prompt_func = render_evaluate_prompt if self.render_reflexion_prompt_func is None: - from app.core.memory.utils.prompt.template_render import render_reflexion_prompt + from app.core.memory.utils.prompt.template_render import ( + render_reflexion_prompt, + ) self.render_reflexion_prompt_func = render_reflexion_prompt if self.conflict_schema is None: @@ -170,7 +184,9 @@ class ReflectionEngine: self.reflexion_schema = ReflexionResultSchema if self.update_query is None: - from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT + from app.repositories.neo4j.cypher_queries import ( + UPDATE_STATEMENT_INVALID_AT, + ) self.update_query = UPDATE_STATEMENT_INVALID_AT self._lazy_init_done = True diff --git a/api/app/core/memory/storage_services/search/__init__.py b/api/app/core/memory/storage_services/search/__init__.py index 04a7a4c2..2bec5bf1 100644 --- a/api/app/core/memory/storage_services/search/__init__.py +++ b/api/app/core/memory/storage_services/search/__init__.py @@ -4,10 +4,20 @@ 本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。 """ -from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult -from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy -from 
app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from app.schemas.memory_config_schema import MemoryConfig + from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy +from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy +from app.core.memory.storage_services.search.search_strategy import ( + SearchResult, + SearchStrategy, +) +from app.core.memory.storage_services.search.semantic_search import ( + SemanticSearchStrategy, +) __all__ = [ "SearchStrategy", @@ -34,7 +44,7 @@ async def run_hybrid_search( include: list[str] | None = None, alpha: float = 0.6, use_forgetting_curve: bool = False, - embedding_id: str | None = None, + memory_config: "MemoryConfig" = None, **kwargs ) -> dict: """运行混合搜索(向后兼容的函数式API) @@ -51,24 +61,26 @@ async def run_hybrid_search( include: 要包含的搜索类别列表 alpha: BM25分数权重(0.0-1.0) use_forgetting_curve: 是否使用遗忘曲线 - embedding_id: 嵌入模型ID + memory_config: MemoryConfig object containing embedding_model_id **kwargs: 其他参数 Returns: dict: 搜索结果字典,格式与旧API兼容 """ - from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient - from app.core.memory.utils.config.config_utils import get_embedder_config - from app.core.memory.utils.config import definitions as config_defs from app.core.models.base import RedBearModelConfig + from app.db import get_db_context + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + from app.services.memory_config_service import MemoryConfigService - # 使用提供的embedding_id或默认值 - emb_id = embedding_id or config_defs.SELECTED_EMBEDDING_ID + if not memory_config: + raise ValueError("memory_config is required for search") # 初始化客户端 connector = Neo4jConnector() - embedder_config_dict = get_embedder_config(emb_id) + with get_db_context() as db: + config_service = MemoryConfigService(db) + 
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id)) embedder_config = RedBearModelConfig(**embedder_config_dict) embedder_client = OpenAIEmbedderClient(embedder_config) diff --git a/api/app/core/memory/storage_services/search/semantic_search.py b/api/app/core/memory/storage_services/search/semantic_search.py index 363ff1aa..b20f90a5 100644 --- a/api/app/core/memory/storage_services/search/semantic_search.py +++ b/api/app/core/memory/storage_services/search/semantic_search.py @@ -5,15 +5,20 @@ 使用余弦相似度进行语义匹配。 """ -from typing import List, Dict, Any, Optional +from typing import Any, Dict, List, Optional + from app.core.logging_config import get_memory_logger -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult -from app.repositories.neo4j.graph_search import search_graph_by_embedding from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.config_utils import get_embedder_config +from app.core.memory.storage_services.search.search_strategy import ( + SearchResult, + SearchStrategy, +) from app.core.memory.utils.config import definitions as config_defs from app.core.models.base import RedBearModelConfig +from app.db import get_db_context +from app.repositories.neo4j.graph_search import search_graph_by_embedding +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.services.memory_config_service import MemoryConfigService logger = get_memory_logger(__name__) @@ -62,7 +67,9 @@ class SemanticSearchStrategy(SearchStrategy): """ try: # 从数据库读取嵌入器配置 - embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID) + with get_db_context() as db: + config_service = MemoryConfigService(db) + embedder_config_dict = config_service.get_embedder_config(config_defs.SELECTED_EMBEDDING_ID) rb_config = RedBearModelConfig( 
model_name=embedder_config_dict["model_name"], provider=embedder_config_dict["provider"], diff --git a/api/app/core/memory/utils/config/__init__.py b/api/app/core/memory/utils/config/__init__.py index c2a8c6ca..f69c13a2 100644 --- a/api/app/core/memory/utils/config/__init__.py +++ b/api/app/core/memory/utils/config/__init__.py @@ -9,7 +9,6 @@ from .config_utils import ( get_chunker_config, get_embedder_config, get_model_config, - get_neo4j_config, get_picture_config, get_pipeline_config, get_pruning_config, @@ -41,7 +40,6 @@ __all__ = [ # config_utils "get_model_config", "get_embedder_config", - "get_neo4j_config", "get_chunker_config", "get_pipeline_config", "get_pruning_config", diff --git a/api/app/core/memory/utils/config/config_utils.py b/api/app/core/memory/utils/config/config_utils.py index b05e176c..7edb2a09 100644 --- a/api/app/core/memory/utils/config/config_utils.py +++ b/api/app/core/memory/utils/config/config_utils.py @@ -1,90 +1,74 @@ -from app.core.memory.models.variate_config import ( - DedupConfig, - ExtractionPipelineConfig, - ForgettingEngineConfig, - StatementExtractionConfig, -) -from app.core.memory.utils.config.definitions import CONFIG -from app.db import get_db -from app.models.models_model import ModelApiKey -from app.services.model_service import ModelConfigService -from fastapi import status -from fastapi.exceptions import HTTPException -from sqlalchemy.orm import Session +""" +Configuration utilities - Backward compatibility layer + +DEPRECATED: These functions now require a db session parameter. +New code should use MemoryConfigService(db) instance directly. + +For functions that don't require db (get_pipeline_config, get_pruning_config), +they are still re-exported here. 
+""" + +import warnings + +from app.services.memory_config_service import MemoryConfigService + +# These functions don't require db - safe to re-export as static methods +get_pipeline_config = MemoryConfigService.get_pipeline_config +get_pruning_config = MemoryConfigService.get_pruning_config -def get_model_config(model_id: str, db: Session | None = None) -> dict: +def get_model_config(model_id: str, db=None): + """DEPRECATED: Use MemoryConfigService(db).get_model_config(model_id) directly.""" if db is None: - db_gen = get_db() # get_db 通常是一个生成器 - db = next(db_gen) # 取到真正的 Session + raise ValueError( + "get_model_config now requires a db session. " + "Use MemoryConfigService(db).get_model_config(model_id) directly." + ) + return MemoryConfigService(db).get_model_config(model_id) - config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) - if not config: - print(f"模型ID {model_id} 不存在") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在") - apiConfig: ModelApiKey = config.api_keys[0] - - # 从环境变量读取超时和重试配置 - from app.core.config import settings - - model_config = { - "model_name": apiConfig.model_name, - "provider": apiConfig.provider, - "api_key": apiConfig.api_key, - "base_url": apiConfig.api_base, - "model_config_id":apiConfig.model_config_id, - "type": config.type, - # 添加超时和重试配置,避免 LLM 请求超时 - "timeout": settings.LLM_TIMEOUT, # 从环境变量读取,默认120秒 - "max_retries": settings.LLM_MAX_RETRIES, # 从环境变量读取,默认2次 - } - # 写入model_config.log文件中 - with open("logs/model_config.log", "a", encoding="utf-8") as f: - f.write(f"模型ID: {model_id}\n") - f.write(f"模型配置信息:\n{model_config}\n") - f.write("=============================\n\n") - return model_config -def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict: +def get_embedder_config(embedding_id: str, db=None): + """DEPRECATED: Use MemoryConfigService(db).get_embedder_config(embedding_id) directly.""" if db is None: - db_gen = get_db() # get_db 通常是一个生成器 - db = next(db_gen) # 
取到真正的 Session + raise ValueError( + "get_embedder_config now requires a db session. " + "Use MemoryConfigService(db).get_embedder_config(embedding_id) directly." + ) + return MemoryConfigService(db).get_embedder_config(embedding_id) - config = ModelConfigService.get_model_by_id(db=db, model_id=embedding_id) - if not config: - print(f"嵌入模型ID {embedding_id} 不存在") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在") - apiConfig: ModelApiKey = config.api_keys[0] - model_config = { - "model_name": apiConfig.model_name, - "provider": apiConfig.provider, - "api_key": apiConfig.api_key, - "base_url": apiConfig.api_base, - "model_config_id":apiConfig.model_config_id, - # Ensure required field for RedBearModelConfig validation - "type": config.type, - # 添加超时和重试配置,避免嵌入服务请求超时 - "timeout": 120.0, # 嵌入服务超时时间(秒) - "max_retries": 5, # 最大重试次数 - } - # 写入embedder_config.log文件中 - with open("logs/embedder_config.log", "a", encoding="utf-8") as f: - f.write(f"嵌入模型ID: {embedding_id}\n") - f.write(f"嵌入模型配置信息:\n{model_config}\n") - f.write("=============================\n\n") - return model_config -def get_neo4j_config() -> dict: - """Retrieves the Neo4j configuration from the config file.""" - return CONFIG.get("neo4j", {}) def get_picture_config(llm_name: str) -> dict: - """Retrieves the configuration for a specific model from the config file.""" + """Retrieves the configuration for a specific model from the config file. + + .. deprecated:: + This function is deprecated and will be removed in a future version. + Use database-backed model configuration instead. + """ + warnings.warn( + "get_picture_config is deprecated and will be removed in a future version. 
" + "Use database-backed model configuration instead.", + DeprecationWarning, + stacklevel=2 + ) for model_config in CONFIG.get("picture_recognition", []): if model_config["llm_name"] == llm_name: return model_config raise ValueError(f"Model '{llm_name}' not found in config.json") + + def get_voice_config(llm_name: str) -> dict: - """Retrieves the configuration for a specific model from the config file.""" + """Retrieves the configuration for a specific model from the config file. + + .. deprecated:: + This function is deprecated and will be removed in a future version. + Use database-backed model configuration instead. + """ + warnings.warn( + "get_voice_config is deprecated and will be removed in a future version. " + "Use database-backed model configuration instead.", + DeprecationWarning, + stacklevel=2 + ) for model_config in CONFIG.get("voice_recognition", []): if model_config["llm_name"] == llm_name: return model_config @@ -92,19 +76,8 @@ def get_voice_config(llm_name: str) -> dict: def get_chunker_config(chunker_strategy: str) -> dict: - """Retrieves the configuration for a specific chunker strategy. + """Retrieves the configuration for a specific chunker strategy.""" - Enhancements: - - Supports default configs for `LLMChunker` and `HybridChunker` if not present. - - Falls back to the first available chunker config when the requested one is missing. 
- """ - # 1) Try to find exact match in config - chunker_list = CONFIG.get("chunker_list", []) - for chunker_config in chunker_list: - if chunker_config.get("chunker_strategy") == chunker_strategy: - return chunker_config - - # 2) Provide sane defaults for newer strategies default_configs = { "RecursiveChunker": { "chunker_strategy": "RecursiveChunker", @@ -112,7 +85,6 @@ def get_chunker_config(chunker_strategy: str) -> dict: "chunk_size": 512, "min_characters_per_chunk": 50 }, - "LLMChunker": { "chunker_strategy": "LLMChunker", "embedding_model": "BAAI/bge-m3", @@ -137,127 +109,6 @@ def get_chunker_config(chunker_strategy: str) -> dict: if chunker_strategy in default_configs: return default_configs[chunker_strategy] - # 3) Fallback: use first available config but tag with requested strategy - if chunker_list: - fallback = chunker_list[0].copy() - fallback["chunker_strategy"] = chunker_strategy - # Non-fatal notice for visibility in logs if any - print(f"Warning: Using first available chunker config as fallback for '{chunker_strategy}'") - return fallback - - # 4) If no configs available at all raise ValueError( - f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available" + f"Chunker '{chunker_strategy}' not found " ) - -#TODO: Fix this - -def get_pipeline_config( - config_id: int, - db: Session | None = None, -) -> ExtractionPipelineConfig: - """Build ExtractionPipelineConfig from database. - - Args: - config_id: Database configuration ID (required). Loads pipeline - settings from the DataConfig table. - db: Optional database session. If not provided, a new session - will be created. - - Returns: - ExtractionPipelineConfig with deduplication, statement extraction, - and forgetting engine settings loaded from database. - - Raises: - ValueError: If config_id not found in database. 
- """ - from app.repositories.data_config_repository import DataConfigRepository - - # Load from database - if db is None: - db_gen = get_db() - db = next(db_gen) - - db_config = DataConfigRepository.get_by_id(db, config_id) - if db_config is None: - raise ValueError(f"Configuration {config_id} not found in database") - - # Build DedupConfig from database - dedup_kwargs = { - "enable_llm_dedup_blockwise": bool(db_config.enable_llm_dedup_blockwise) if db_config.enable_llm_dedup_blockwise is not None else False, - "enable_llm_disambiguation": bool(db_config.enable_llm_disambiguation) if db_config.enable_llm_disambiguation is not None else False, - } - - # Fuzzy thresholds - if db_config.t_name_strict is not None: - dedup_kwargs["fuzzy_name_threshold_strict"] = db_config.t_name_strict - if db_config.t_type_strict is not None: - dedup_kwargs["fuzzy_type_threshold_strict"] = db_config.t_type_strict - if db_config.t_overall is not None: - dedup_kwargs["fuzzy_overall_threshold"] = db_config.t_overall - - dedup_config = DedupConfig(**dedup_kwargs) - - # Build StatementExtractionConfig from database - stmt_kwargs = {} - if db_config.statement_granularity is not None: - stmt_kwargs["statement_granularity"] = db_config.statement_granularity - if db_config.include_dialogue_context is not None: - stmt_kwargs["include_dialogue_context"] = bool(db_config.include_dialogue_context) - if db_config.max_context is not None: - stmt_kwargs["max_dialogue_context_chars"] = db_config.max_context - - stmt_config = StatementExtractionConfig(**stmt_kwargs) - - # Build ForgettingEngineConfig from database - forget_kwargs = {} - if db_config.offset is not None: - forget_kwargs["offset"] = db_config.offset - if db_config.lambda_time is not None: - forget_kwargs["lambda_time"] = db_config.lambda_time - if db_config.lambda_mem is not None: - forget_kwargs["lambda_mem"] = db_config.lambda_mem - - forget_config = ForgettingEngineConfig(**forget_kwargs) - - return ExtractionPipelineConfig( - 
statement_extraction=stmt_config, - deduplication=dedup_config, - forgetting_engine=forget_config, - ) - - -def get_pruning_config( - config_id: int, - db: Session | None = None, -) -> dict: - """Retrieve semantic pruning config from database. - - Args: - config_id: Database configuration ID (required). - db: Optional database session. - - Returns: - Dict suitable for PruningConfig.model_validate with keys: - - pruning_switch: bool - - pruning_scene: str ("education" | "online_service" | "outbound") - - pruning_threshold: float (0-0.9) - - Raises: - ValueError: If config_id not found in database. - """ - from app.repositories.data_config_repository import DataConfigRepository - - if db is None: - db_gen = get_db() - db = next(db_gen) - - db_config = DataConfigRepository.get_by_id(db, config_id) - if db_config is None: - raise ValueError(f"Configuration {config_id} not found in database") - - return { - "pruning_switch": bool(db_config.pruning_enabled) if db_config.pruning_enabled is not None else False, - "pruning_scene": db_config.pruning_scene or "education", - "pruning_threshold": float(db_config.pruning_threshold) if db_config.pruning_threshold is not None else 0.5, - } diff --git a/api/app/core/memory/utils/config/definitions.py b/api/app/core/memory/utils/config/definitions.py index cc1aef66..fc07c2cc 100644 --- a/api/app/core/memory/utils/config/definitions.py +++ b/api/app/core/memory/utils/config/definitions.py @@ -1,269 +1,268 @@ -""" -配置加载模块 - DEPRECATED +# """ +# 配置加载模块 - DEPRECATED -⚠️ DEPRECATION NOTICE ⚠️ -This module is deprecated and will be removed in a future version. -Global configuration variables have been eliminated in favor of dependency injection. +# ⚠️ DEPRECATION NOTICE ⚠️ +# This module is deprecated and will be removed in a future version. +# Global configuration variables have been eliminated in favor of dependency injection. 
-Use the new MemoryConfig system instead: -- app.core.memory_config.config.MemoryConfig for configuration objects -- app.services.memory_agent_service.MemoryAgentService.load_memory_config() -- app.services.memory_storage_service.MemoryStorageService.load_memory_config() +# Use the new MemoryConfig system instead: +# - app.schemas.memory_config_schema.MemoryConfig for configuration objects +# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id) -阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED -阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED -阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED -""" -import json -import os -import threading -from datetime import datetime, timedelta -from typing import Any, Dict, Optional +# 阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED +# 阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED +# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED +# """ +# import json +# import os +# import threading +# from datetime import datetime, timedelta +# from typing import Any, Dict, Optional -#TODO: Fix this +# #TODO: Fix this -try: - from dotenv import load_dotenv - load_dotenv() -except Exception: - pass +# try: +# from dotenv import load_dotenv +# load_dotenv() +# except Exception: +# pass -# Import unified configuration system -try: - from app.core.config import settings - USE_UNIFIED_CONFIG = True -except ImportError: - USE_UNIFIED_CONFIG = False - settings = None +# # Import unified configuration system +# try: +# from app.core.config import settings +# USE_UNIFIED_CONFIG = True +# except ImportError: +# USE_UNIFIED_CONFIG = False +# settings = None -# PROJECT_ROOT 应该指向 app/core/memory/ 目录 -# __file__ = app/core/memory/utils/config/definitions.py -# os.path.dirname(__file__) = app/core/memory/utils/config -# os.path.dirname(...) = app/core/memory/utils -# os.path.dirname(...) 
= app/core/memory -PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# # PROJECT_ROOT 应该指向 app/core/memory/ 目录 +# # __file__ = app/core/memory/utils/config/definitions.py +# # os.path.dirname(__file__) = app/core/memory/utils/config +# # os.path.dirname(...) = app/core/memory/utils +# # os.path.dirname(...) = app/core/memory +# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -# DEPRECATED: Global configuration lock removed -# Use MemoryConfig objects with dependency injection instead +# # DEPRECATED: Global configuration lock removed +# # Use MemoryConfig objects with dependency injection instead -# DEPRECATED: Legacy config.json loading removed -# Use MemoryConfig objects with dependency injection instead -CONFIG = {} +# # DEPRECATED: Legacy config.json loading removed +# # Use MemoryConfig objects with dependency injection instead +# CONFIG = {} -DEFAULT_VALUES = { - "llm_name": "openai/qwen-plus", - "embedding_name": "openai/nomic-embed-text:v1.5", - "chunker_strategy": "RecursiveChunker", - "group_id": "group_123", - "user_id": "default_user", - "apply_id": "default_apply", - "llm_agent_name": "openai/qwen-plus", - "llm_verify_name": "openai/qwen-plus", - "llm_image_recognition": "openai/qwen-plus", - "llm_voice_recognition": "openai/qwen-plus", - "prompt_level": "DEBUG", - "reflexion_iteration_period": "3", - "reflexion_range": "retrieval", - "reflexion_baseline": "TIME", -} +# DEFAULT_VALUES = { +# "llm_name": "openai/qwen-plus", +# "embedding_name": "openai/nomic-embed-text:v1.5", +# "chunker_strategy": "RecursiveChunker", +# "group_id": "group_123", +# "user_id": "default_user", +# "apply_id": "default_apply", +# "llm_agent_name": "openai/qwen-plus", +# "llm_verify_name": "openai/qwen-plus", +# "llm_image_recognition": "openai/qwen-plus", +# "llm_voice_recognition": "openai/qwen-plus", +# "prompt_level": "DEBUG", +# "reflexion_iteration_period": "3", +# "reflexion_range": 
"retrieval", +# "reflexion_baseline": "TIME", +# } -# DEPRECATED: Legacy global variables for backward compatibility only -# These will be removed in a future version -# Use MemoryConfig objects with dependency injection instead -LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true" -SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"]) +# # DEPRECATED: Legacy global variables for backward compatibility only +# # These will be removed in a future version +# # Use MemoryConfig objects with dependency injection instead +# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true" +# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"]) -# 阶段 1: 从 runtime.json 加载配置(路径 A) -def _load_from_runtime_json() -> Dict[str, Any]: - """ - DEPRECATED: Legacy runtime.json loading +# # 阶段 1: 从 runtime.json 加载配置(路径 A) +# def _load_from_runtime_json() -> Dict[str, Any]: +# """ +# DEPRECATED: Legacy runtime.json loading - ⚠️ This function is deprecated and will be removed in a future version. - Use MemoryConfig objects with dependency injection instead. +# ⚠️ This function is deprecated and will be removed in a future version. +# Use MemoryConfig objects with dependency injection instead. - Returns: - Dict[str, Any]: Empty configuration (legacy support only) - """ - import warnings - warnings.warn( - "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.", - DeprecationWarning, - stacklevel=2 - ) - return {"selections": {}} +# Returns: +# Dict[str, Any]: Empty configuration (legacy support only) +# """ +# import warnings +# warnings.warn( +# "Runtime JSON loading is deprecated. 
Use MemoryConfig objects with dependency injection instead.", +# DeprecationWarning, +# stacklevel=2 +# ) +# return {"selections": {}} -# 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器 -# 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代 -# 保留此函数仅为向后兼容 -def _load_from_database() -> Optional[Dict[str, Any]]: - """ - DEPRECATED: Legacy database configuration loading +# # 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器 +# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代 +# # 保留此函数仅为向后兼容 +# def _load_from_database() -> Optional[Dict[str, Any]]: +# """ +# DEPRECATED: Legacy database configuration loading - ⚠️ This function is deprecated and will be removed in a future version. - Use MemoryConfig objects with dependency injection instead. +# ⚠️ This function is deprecated and will be removed in a future version. +# Use MemoryConfig objects with dependency injection instead. - Returns: - Optional[Dict[str, Any]]: None (deprecated functionality) - """ - import warnings - warnings.warn( - "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.", - DeprecationWarning, - stacklevel=2 - ) - return None +# Returns: +# Optional[Dict[str, Any]]: None (deprecated functionality) +# """ +# import warnings +# warnings.warn( +# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.", +# DeprecationWarning, +# stacklevel=2 +# ) +# return None -# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED -def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None: - """ - DEPRECATED: 将运行时配置暴露为全局常量供项目使用 +# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED +# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None: +# """ +# DEPRECATED: 将运行时配置暴露为全局常量供项目使用 - ⚠️ This function is deprecated and will be removed in a future version. - Global configuration variables have been eliminated in favor of dependency injection. +# ⚠️ This function is deprecated and will be removed in a future version. 
+# Global configuration variables have been eliminated in favor of dependency injection. - Use the new MemoryConfig system instead: - - app.core.memory_config.config.MemoryConfig for configuration objects - - Pass configuration objects as parameters instead of using global variables +# Use the new MemoryConfig system instead: +# - app.core.memory_config.config.MemoryConfig for configuration objects +# - Pass configuration objects as parameters instead of using global variables - Args: - runtime_cfg: 运行时配置字典 - """ - import warnings - warnings.warn( - "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.", - DeprecationWarning, - stacklevel=2 - ) +# Args: +# runtime_cfg: 运行时配置字典 +# """ +# import warnings +# warnings.warn( +# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.", +# DeprecationWarning, +# stacklevel=2 +# ) - # Keep minimal global state for backward compatibility only - # These will be removed in a future version - global RUNTIME_CONFIG, SELECTIONS +# # Keep minimal global state for backward compatibility only +# # These will be removed in a future version +# global RUNTIME_CONFIG, SELECTIONS - RUNTIME_CONFIG = runtime_cfg - SELECTIONS = RUNTIME_CONFIG.get("selections", {}) +# RUNTIME_CONFIG = runtime_cfg +# SELECTIONS = RUNTIME_CONFIG.get("selections", {}) - # All other global variables have been removed - # Use MemoryConfig objects instead +# # All other global variables have been removed +# # Use MemoryConfig objects instead -# 初始化:使用统一配置加载器 -def _initialize_configuration() -> None: - """ - DEPRECATED: Legacy configuration initialization +# # 初始化:使用统一配置加载器 +# def _initialize_configuration() -> None: +# """ +# DEPRECATED: Legacy configuration initialization - ⚠️ This function is deprecated and will be removed in a future version. - Use MemoryConfig objects with dependency injection instead. 
- """ - import warnings - warnings.warn( - "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.", - DeprecationWarning, - stacklevel=2 - ) - # Initialize with empty configuration for backward compatibility - _expose_runtime_constants({"selections": {}}) +# ⚠️ This function is deprecated and will be removed in a future version. +# Use MemoryConfig objects with dependency injection instead. +# """ +# import warnings +# warnings.warn( +# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.", +# DeprecationWarning, +# stacklevel=2 +# ) +# # Initialize with empty configuration for backward compatibility +# _expose_runtime_constants({"selections": {}}) -# 模块加载时自动初始化配置 -_initialize_configuration() +# # 模块加载时自动初始化配置 +# _initialize_configuration() -# DEPRECATED: Global variables removed -# These variables have been eliminated in favor of dependency injection -# Use MemoryConfig objects instead of accessing global variables +# # DEPRECATED: Global variables removed +# # These variables have been eliminated in favor of dependency injection +# # Use MemoryConfig objects instead of accessing global variables -# 公共 API:动态重新加载配置 -def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool: - """ - DEPRECATED: Legacy configuration reloading +# # 公共 API:动态重新加载配置 +# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool: +# """ +# DEPRECATED: Legacy configuration reloading - ⚠️ This function is deprecated and will be removed in a future version. - Use MemoryConfig objects with dependency injection instead. +# ⚠️ This function is deprecated and will be removed in a future version. +# Use MemoryConfig objects with dependency injection instead. 
- For new code, use: - - app.services.memory_agent_service.MemoryAgentService.load_memory_config() - - app.services.memory_storage_service.MemoryStorageService.load_memory_config() +# For new code, use: +# - app.services.memory_agent_service.MemoryAgentService.load_memory_config() +# - app.services.memory_storage_service.MemoryStorageService.load_memory_config() - Args: - config_id: Configuration ID (deprecated) - force_reload: Force reload flag (deprecated) +# Args: +# config_id: Configuration ID (deprecated) +# force_reload: Force reload flag (deprecated) - Returns: - bool: Always returns False (deprecated functionality) - """ - import logging - import warnings +# Returns: +# bool: Always returns False (deprecated functionality) +# """ +# import logging +# import warnings - logger = logging.getLogger(__name__) +# logger = logging.getLogger(__name__) - warnings.warn( - "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.", - DeprecationWarning, - stacklevel=2 - ) +# warnings.warn( +# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.", +# DeprecationWarning, +# stacklevel=2 +# ) - logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. " - "Use MemoryConfig objects with dependency injection instead.") +# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. " +# "Use MemoryConfig objects with dependency injection instead.") - return False +# return False -def get_current_config_id() -> Optional[str]: - """ - DEPRECATED: Legacy config ID retrieval +# def get_current_config_id() -> Optional[str]: +# """ +# DEPRECATED: Legacy config ID retrieval - ⚠️ This function is deprecated and will be removed in a future version. - Use MemoryConfig objects with dependency injection instead. 
+# ⚠️ This function is deprecated and will be removed in a future version. +# Use MemoryConfig objects with dependency injection instead. - Returns: - Optional[str]: None (deprecated functionality) - """ - import warnings - warnings.warn( - "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.", - DeprecationWarning, - stacklevel=2 - ) - return None +# Returns: +# Optional[str]: None (deprecated functionality) +# """ +# import warnings +# warnings.warn( +# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.", +# DeprecationWarning, +# stacklevel=2 +# ) +# return None -def ensure_fresh_config(config_id = None) -> bool: - """ - DEPRECATED: Legacy configuration freshness check +# def ensure_fresh_config(config_id = None) -> bool: +# """ +# DEPRECATED: Legacy configuration freshness check - ⚠️ This function is deprecated and will be removed in a future version. - Use MemoryConfig objects with dependency injection instead. +# ⚠️ This function is deprecated and will be removed in a future version. +# Use MemoryConfig objects with dependency injection instead. - For new code, use: - - app.services.memory_agent_service.MemoryAgentService.load_memory_config() - - app.services.memory_storage_service.MemoryStorageService.load_memory_config() +# For new code, use: +# - app.services.memory_agent_service.MemoryAgentService.load_memory_config() +# - app.services.memory_storage_service.MemoryStorageService.load_memory_config() - Args: - config_id: Configuration ID (deprecated) +# Args: +# config_id: Configuration ID (deprecated) - Returns: - bool: Always returns False (deprecated functionality) - """ - import logging - import warnings +# Returns: +# bool: Always returns False (deprecated functionality) +# """ +# import logging +# import warnings - logger = logging.getLogger(__name__) +# logger = logging.getLogger(__name__) - warnings.warn( - "ensure_fresh_config is deprecated. 
Use MemoryConfig objects with dependency injection instead.", - DeprecationWarning, - stacklevel=2 - ) +# warnings.warn( +# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.", +# DeprecationWarning, +# stacklevel=2 +# ) - logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. " - "Use MemoryConfig objects with dependency injection instead.") +# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. " +# "Use MemoryConfig objects with dependency injection instead.") - return False +# return False diff --git a/api/app/core/memory/utils/embedder/embedder_utils.py b/api/app/core/memory/utils/embedder/embedder_utils.py index 86899e30..0a384a87 100644 --- a/api/app/core/memory/utils/embedder/embedder_utils.py +++ b/api/app/core/memory/utils/embedder/embedder_utils.py @@ -6,8 +6,9 @@ This module provides centralized functions for creating embedder clients. from typing import TYPE_CHECKING from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.config_utils import get_embedder_config from app.core.models.base import RedBearModelConfig +from app.db import get_db_context +from app.services.memory_config_service import MemoryConfigService if TYPE_CHECKING: from app.schemas.memory_config_schema import MemoryConfig @@ -66,7 +67,8 @@ def get_embedder_client(embedding_id: str) -> OpenAIEmbedderClient: raise ValueError("Embedding ID is required but was not provided") try: - embedder_config_dict = get_embedder_config(embedding_id) + with get_db_context() as db: + embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_id) except Exception as e: raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e diff --git a/api/app/core/memory/utils/llm/llm_utils.py b/api/app/core/memory/utils/llm/llm_utils.py index a4b327da..19d76d68 100644 --- 
a/api/app/core/memory/utils/llm/llm_utils.py +++ b/api/app/core/memory/utils/llm/llm_utils.py @@ -1,9 +1,9 @@ from typing import TYPE_CHECKING from app.core.memory.llm_tools.openai_client import OpenAIClient -from app.core.memory.utils.config.config_utils import get_model_config from app.core.models.base import RedBearModelConfig from pydantic import BaseModel +from sqlalchemy.orm import Session if TYPE_CHECKING: from app.schemas.memory_config_schema import MemoryConfig @@ -13,105 +13,225 @@ async def handle_response(response: type[BaseModel]) -> dict: return response.model_dump() -def get_llm_client_from_config(memory_config: "MemoryConfig") -> OpenAIClient: +class MemoryClientFactory: """ - Get LLM client from MemoryConfig object. + Factory for creating LLM, embedder, and reranker clients. - **PREFERRED METHOD**: Use this function in production code when you have a MemoryConfig object. - This ensures proper configuration management and multi-tenant support. + Initialize once with db session, then call methods without passing db each time. 
- Args: - memory_config: MemoryConfig object containing llm_model_id - - Returns: - OpenAIClient: Initialized LLM client - - Raises: - ValueError: If LLM model ID is not configured or client initialization fails - Example: - >>> llm_client = get_llm_client_from_config(memory_config) + >>> factory = MemoryClientFactory(db) + >>> llm_client = factory.get_llm_client(model_id) + >>> embedder_client = factory.get_embedder_client(embedding_id) """ - if not memory_config.llm_model_id: - raise ValueError( - f"Configuration {memory_config.config_id} has no LLM model configured" - ) - return get_llm_client(str(memory_config.llm_model_id)) + + def __init__(self, db: Session): + from app.services.memory_config_service import MemoryConfigService + self._config_service = MemoryConfigService(db) + + def get_llm_client(self, llm_id: str) -> OpenAIClient: + """Get LLM client by model ID.""" + if not llm_id: + raise ValueError("LLM ID is required") + + try: + model_config = self._config_service.get_model_config(llm_id) + except Exception as e: + raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e + + try: + return OpenAIClient( + RedBearModelConfig( + model_name=model_config.get("model_name"), + provider=model_config.get("provider"), + api_key=model_config.get("api_key"), + base_url=model_config.get("base_url") + ), + type_=model_config.get("type") + ) + except Exception as e: + model_name = model_config.get('model_name', 'unknown') + raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e + + def get_embedder_client(self, embedding_id: str): + """Get embedder client by model ID.""" + from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient + + if not embedding_id: + raise ValueError("Embedding ID is required") + + try: + embedder_config = self._config_service.get_embedder_config(embedding_id) + except Exception as e: + raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e + + try: + return 
OpenAIEmbedderClient( + RedBearModelConfig( + model_name=embedder_config.get("model_name"), + provider=embedder_config.get("provider"), + api_key=embedder_config.get("api_key"), + base_url=embedder_config.get("base_url") + ) + ) + except Exception as e: + model_name = embedder_config.get('model_name', 'unknown') + raise ValueError(f"Failed to initialize embedder client for model '{model_name}': {str(e)}") from e + + def get_reranker_client(self, rerank_id: str) -> OpenAIClient: + """Get reranker client by model ID.""" + if not rerank_id: + raise ValueError("Rerank ID is required") + + try: + model_config = self._config_service.get_model_config(rerank_id) + except Exception as e: + raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e + + try: + return OpenAIClient( + RedBearModelConfig( + model_name=model_config.get("model_name"), + provider=model_config.get("provider"), + api_key=model_config.get("api_key"), + base_url=model_config.get("base_url") + ), + type_=model_config.get("type") + ) + except Exception as e: + model_name = model_config.get('model_name', 'unknown') + raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e + + def get_llm_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient: + """Get LLM client from MemoryConfig object. + + Args: + memory_config: Configuration containing llm_model_id + + Returns: + OpenAIClient configured for the LLM model + + Raises: + ValueError: If memory_config has no LLM model configured + """ + if not memory_config.llm_model_id: + raise ValueError( + f"Configuration {memory_config.config_id} has no LLM model configured" + ) + return self.get_llm_client(str(memory_config.llm_model_id)) + + def get_embedder_client_from_config(self, memory_config: "MemoryConfig"): + """Get embedder client from MemoryConfig object. 
+ + Args: + memory_config: Configuration containing embedding_model_id + + Returns: + OpenAIEmbedderClient configured for the embedding model + + Raises: + ValueError: If memory_config has no embedding model configured + """ + if not memory_config.embedding_model_id: + raise ValueError( + f"Configuration {memory_config.config_id} has no embedding model configured" + ) + return self.get_embedder_client(str(memory_config.embedding_model_id)) + + def get_reranker_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient: + """Get reranker client from MemoryConfig object. + + Args: + memory_config: Configuration containing rerank_model_id + + Returns: + OpenAIClient configured for the reranker model + + Raises: + ValueError: If memory_config has no rerank model configured + """ + if not memory_config.rerank_model_id: + raise ValueError( + f"Configuration {memory_config.config_id} has no rerank model configured" + ) + return self.get_reranker_client(str(memory_config.rerank_model_id)) -def get_llm_client(llm_id: str): - """ - Get LLM client by model ID. +# Legacy functions for backward compatibility +def get_llm_client_from_config(memory_config: "MemoryConfig", db: Session) -> OpenAIClient: + """Get LLM client from MemoryConfig object. - **LEGACY/TEST METHOD**: Use this function only for: - - Test/evaluation code where you have a model ID directly - - Legacy code that hasn't been migrated to MemoryConfig yet + DEPRECATED: Use MemoryClientFactory(db).get_llm_client_from_config(memory_config) instead. - For production code with MemoryConfig, use get_llm_client_from_config() instead. + This function is maintained for backward compatibility during migration to the + factory pattern. New code should create a MemoryClientFactory instance and use + its get_llm_client_from_config method directly. 
Args: - llm_id: LLM model ID (required) + memory_config: Configuration containing llm_model_id + db: Database session Returns: - OpenAIClient: Initialized LLM client + OpenAIClient configured for the LLM model Raises: - ValueError: If llm_id is not provided or client initialization fails - - Example: - >>> # For tests/evaluations only - >>> llm_client = get_llm_client("model-uuid-string") + ValueError: If memory_config has no LLM model configured """ - if not llm_id: - raise ValueError("LLM ID is required but was not provided") - - try: - model_config = get_model_config(llm_id) - except Exception as e: - raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e - - try: - llm_client = OpenAIClient(RedBearModelConfig( - model_name=model_config.get("model_name"), - provider=model_config.get("provider"), - api_key=model_config.get("api_key"), - base_url=model_config.get("base_url") - ),type_=model_config.get("type")) - return llm_client - except Exception as e: - model_name = model_config.get('model_name', 'unknown') - raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e + return MemoryClientFactory(db).get_llm_client_from_config(memory_config) -def get_reranker_client(rerank_id: str): - """ - Get an LLM client configured for reranking. +def get_llm_client(llm_id: str, db: Session) -> OpenAIClient: + """Get LLM client by model ID. + + DEPRECATED: Use MemoryClientFactory(db).get_llm_client(llm_id) instead. + + This function is maintained for backward compatibility during migration to the + factory pattern. New code should create a MemoryClientFactory instance and use + its get_llm_client method directly. 
Args: - rerank_id: Reranker model ID (required) + llm_id: LLM model ID + db: Database session Returns: - OpenAIClient: Initialized client for the reranker model - - Raises: - ValueError: If rerank_id is not provided or client initialization fails + OpenAIClient configured for the LLM model """ - if not rerank_id: - raise ValueError("Rerank ID is required but was not provided") + return MemoryClientFactory(db).get_llm_client(llm_id) + + +def get_embedder_client(embedding_id: str, db: Session): + """Get embedder client by model ID. - try: - model_config = get_model_config(rerank_id) - except Exception as e: - raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e + DEPRECATED: Use MemoryClientFactory(db).get_embedder_client(embedding_id) instead. - try: - reranker_client = OpenAIClient(RedBearModelConfig( - model_name=model_config.get("model_name"), - provider=model_config.get("provider"), - api_key=model_config.get("api_key"), - base_url=model_config.get("base_url") - ),type_=model_config.get("type")) - return reranker_client - except Exception as e: - model_name = model_config.get('model_name', 'unknown') - raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e \ No newline at end of file + This function is maintained for backward compatibility during migration to the + factory pattern. New code should create a MemoryClientFactory instance and use + its get_embedder_client method directly. + + Args: + embedding_id: Embedding model ID + db: Database session + + Returns: + OpenAIEmbedderClient configured for the embedding model + """ + return MemoryClientFactory(db).get_embedder_client(embedding_id) + + +def get_reranker_client(rerank_id: str, db: Session) -> OpenAIClient: + """Get reranker client by model ID. + + DEPRECATED: Use MemoryClientFactory(db).get_reranker_client(rerank_id) instead. + + This function is maintained for backward compatibility during migration to the + factory pattern. 
New code should create a MemoryClientFactory instance and use + its get_reranker_client method directly. + + Args: + rerank_id: Reranker model ID + db: Database session + + Returns: + OpenAIClient configured for the reranker model + """ + return MemoryClientFactory(db).get_reranker_client(rerank_id) diff --git a/api/app/core/memory/utils/self_reflexion_utils/evaluate.py b/api/app/core/memory/utils/self_reflexion_utils/evaluate.py index 0ea68461..4d1835cd 100644 --- a/api/app/core/memory/utils/self_reflexion_utils/evaluate.py +++ b/api/app/core/memory/utils/self_reflexion_utils/evaluate.py @@ -6,11 +6,12 @@ """ import logging -from typing import List, Any import time +from typing import Any, List +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.prompt.template_render import render_evaluate_prompt -from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.db import get_db_context from app.schemas.memory_storage_schema import ConflictResultSchema from pydantic import BaseModel @@ -25,7 +26,9 @@ async def conflict(evaluate_data: List[Any]) -> List[Any]: 冲突记忆列表(JSON 数组)。 """ from app.core.memory.utils.config import definitions as config_defs - client = get_llm_client(config_defs.SELECTED_LLM_ID) + with get_db_context() as db: + factory = MemoryClientFactory(db) + client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema) messages = [{"role": "user", "content": rendered_prompt}] print(f"提示词长度: {len(rendered_prompt)}") diff --git a/api/app/core/memory/utils/self_reflexion_utils/reflexion.py b/api/app/core/memory/utils/self_reflexion_utils/reflexion.py index 6835b868..1b915118 100644 --- a/api/app/core/memory/utils/self_reflexion_utils/reflexion.py +++ b/api/app/core/memory/utils/self_reflexion_utils/reflexion.py @@ -6,11 +6,12 @@ """ import logging -from typing import List, Any import time +from typing import Any, List +from 
app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.prompt.template_render import render_reflexion_prompt -from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.db import get_db_context from app.schemas.memory_storage_schema import ReflexionResultSchema from pydantic import BaseModel @@ -25,7 +26,9 @@ async def reflexion(ref_data: List[Any]) -> List[Any]: 反思结果列表(JSON 数组)。 """ from app.core.memory.utils.config import definitions as config_defs - client = get_llm_client(config_defs.SELECTED_LLM_ID) + with get_db_context() as db: + factory = MemoryClientFactory(db) + client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema) messages = [{"role": "user", "content": rendered_prompt}] print(f"提示词长度: {len(rendered_prompt)}") diff --git a/api/app/core/rag_utils/chunk_insight.py b/api/app/core/rag_utils/chunk_insight.py index 2c96160e..e904e53d 100644 --- a/api/app/core/rag_utils/chunk_insight.py +++ b/api/app/core/rag_utils/chunk_insight.py @@ -5,16 +5,24 @@ This module provides functionality to analyze chunk content and generate insight """ import asyncio -from typing import List, Dict, Any from collections import Counter +from typing import Any, Dict, List + +from app.core.logging_config import get_business_logger +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context from pydantic import BaseModel, Field -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.logging_config import get_business_logger - business_logger = get_business_logger() +def _get_llm_client(): + """Get LLM client using db context.""" + with get_db_context() as db: + factory = MemoryClientFactory(db) + return factory.get_llm_client(None) # Uses default LLM + + class ChunkInsight(BaseModel): """Pydantic model for chunk insight.""" insight: str = Field(..., description="对chunk内容的深度洞察分析") @@ 
-40,7 +48,7 @@ async def classify_chunk_domain(chunk: str) -> str: Domain name """ try: - llm_client = get_llm_client() + llm_client = _get_llm_client() prompt = f"""请将以下文本内容归类到最合适的领域中。 @@ -177,7 +185,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str ] # 调用LLM生成洞察 - llm_client = get_llm_client() + llm_client = _get_llm_client() response = await llm_client.chat(messages=messages) insight = response.content.strip() diff --git a/api/app/core/rag_utils/chunk_summary.py b/api/app/core/rag_utils/chunk_summary.py index 971d6907..7f69af88 100644 --- a/api/app/core/rag_utils/chunk_summary.py +++ b/api/app/core/rag_utils/chunk_summary.py @@ -5,15 +5,23 @@ This module provides functionality to summarize chunk content using LLM. """ import asyncio -from typing import List, Dict, Any +from typing import Any, Dict, List + +from app.core.logging_config import get_business_logger +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context from pydantic import BaseModel, Field -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.logging_config import get_business_logger - business_logger = get_business_logger() +def _get_llm_client(): + """Get LLM client using db context.""" + with get_db_context() as db: + factory = MemoryClientFactory(db) + return factory.get_llm_client(None) # Uses default LLM + + class ChunkSummary(BaseModel): """Pydantic model for chunk summary.""" summary: str = Field(..., description="简洁的chunk内容摘要") @@ -59,7 +67,7 @@ async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str ] # 调用LLM生成摘要 - llm_client = get_llm_client() + llm_client = _get_llm_client() response = await llm_client.chat(messages=messages) summary = response.content.strip() diff --git a/api/app/core/rag_utils/chunk_tags.py b/api/app/core/rag_utils/chunk_tags.py index 719f97e6..2057f8ac 100644 --- a/api/app/core/rag_utils/chunk_tags.py +++ 
b/api/app/core/rag_utils/chunk_tags.py @@ -7,14 +7,22 @@ This module provides functionality to extract meaningful tags from chunk content import asyncio from collections import Counter from typing import List, Tuple + +from app.core.logging_config import get_business_logger +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context from pydantic import BaseModel, Field -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.logging_config import get_business_logger - business_logger = get_business_logger() +def _get_llm_client(): + """Get LLM client using db context.""" + with get_db_context() as db: + factory = MemoryClientFactory(db) + return factory.get_llm_client(None) # Uses default LLM + + class ExtractedTags(BaseModel): """Pydantic model for extracted tags.""" tags: List[str] = Field(..., description="从文本中提取的关键标签列表") @@ -56,7 +64,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: "标签应该是名词或名词短语,能够准确概括文本的核心内容。" ) - llm_client = get_llm_client() + llm_client = _get_llm_client() # 为每个chunk单独提取标签,然后统计频率 all_tags = [] @@ -151,7 +159,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch ] # 调用LLM提取人物形象 - llm_client = get_llm_client() + llm_client = _get_llm_client() structured_response = await llm_client.response_structured( messages=messages, response_model=ExtractedPersona diff --git a/api/app/schemas/memory_config_schema.py b/api/app/schemas/memory_config_schema.py index 2ea24be8..171abb7a 100644 --- a/api/app/schemas/memory_config_schema.py +++ b/api/app/schemas/memory_config_schema.py @@ -391,6 +391,29 @@ class MemoryConfig: embedding_params: Dict[str, Any] = field(default_factory=dict) config_version: str = "2.0" + # Pipeline config: Deduplication + enable_llm_dedup_blockwise: bool = False + enable_llm_disambiguation: bool = False + deep_retrieval: bool = True + t_type_strict: float = 0.8 + t_name_strict: float = 0.8 + 
t_overall: float = 0.8 + + # Pipeline config: Statement extraction + statement_granularity: int = 2 + include_dialogue_context: bool = False + max_dialogue_context_chars: int = 1000 + + # Pipeline config: Forgetting engine + lambda_time: float = 0.5 + lambda_mem: float = 0.5 + offset: float = 0.0 + + # Pipeline config: Pruning + pruning_enabled: bool = False + pruning_scene: Optional[str] = "education" + pruning_threshold: float = 0.5 + def __post_init__(self): """Validate configuration after initialization.""" if not self.config_name or not self.config_name.strip(): diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 489ffe4b..c0d2e3ff 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -3,26 +3,26 @@ 提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。 """ -import time -import uuid -import json import asyncio import datetime -from typing import Dict, Any, Optional, List, AsyncGenerator -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy.orm import Session -from sqlalchemy import select +import json +import time +import uuid +from typing import Any, AsyncGenerator, Dict, List, Optional -from app.models import AgentConfig, ModelConfig, ModelApiKey -from app.core.exceptions import BusinessException from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger -from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole +from app.core.rag.nlp.search import knowledge_retrieval +from app.models import AgentConfig, ModelApiKey, ModelConfig +from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message +from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger -from app.core.rag.nlp.search import 
knowledge_retrieval -from app.services.langchain_tool_server import Search +from langchain.tools import tool +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session logger = get_business_logger() class KnowledgeRetrievalInput(BaseModel): @@ -83,17 +83,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str """ logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}") try: - memory_content = asyncio.run( - MemoryAgentService().read_memory( - group_id=end_user_id, - message=question, - history=[], - search_switch="1", - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + from app.db import get_db + db = next(get_db()) + try: + memory_content = asyncio.run( + MemoryAgentService().read_memory( + group_id=end_user_id, + message=question, + history=[], + search_switch="1", + config_id=config_id, + db=db, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id + ) ) - ) + finally: + db.close() logger.info(f'用户ID:Agent:{end_user_id}') logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id}) @@ -713,9 +719,9 @@ class DraftRunService: Raises: BusinessException: 当指定的会话不存在时 """ - from app.services.conversation_service import ConversationService - from app.schemas.conversation_schema import ConversationCreate from app.models import Conversation as ConversationModel + from app.schemas.conversation_schema import ConversationCreate + from app.services.conversation_service import ConversationService conversation_service = ConversationService(self.db) diff --git a/api/app/services/emotion_analytics_service.py b/api/app/services/emotion_analytics_service.py index 6952256e..2e63eeb0 100644 --- a/api/app/services/emotion_analytics_service.py +++ b/api/app/services/emotion_analytics_service.py @@ -7,14 +7,15 @@ Classes: EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能 """ -from typing import Dict, Any, Optional, List -import 
statistics import json -from pydantic import BaseModel, Field +import statistics +from typing import Any, Dict, List, Optional +from app.core.logging_config import get_business_logger from app.repositories.neo4j.emotion_repository import EmotionRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.logging_config import get_business_logger +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session logger = get_business_logger() @@ -454,7 +455,7 @@ class EmotionAnalyticsService: async def generate_emotion_suggestions( self, end_user_id: str, - config_id: Optional[int] = None + db: Session, ) -> Dict[str, Any]: """生成个性化情绪建议 @@ -462,7 +463,7 @@ class EmotionAnalyticsService: Args: end_user_id: 宿主ID(用户组ID) - config_id: 配置ID(可选,用于从数据库加载LLM配置) + db: 数据库会话 Returns: Dict: 包含个性化建议的响应: @@ -470,14 +471,32 @@ class EmotionAnalyticsService: - suggestions: 建议列表(3-5条) """ try: - logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}") + logger.info(f"生成个性化情绪建议: user={end_user_id}") - # 1. 如果提供了 config_id,从数据库加载配置 - if config_id is not None: - from app.core.memory.utils.config.definitions import reload_configuration_from_database - config_loaded = reload_configuration_from_database(config_id) - if not config_loaded: - logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置") + # 1. 
从 end_user_id 获取关联的 memory_config_id + llm_client = None + try: + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") + + if config_id is not None: + from app.services.memory_config_service import ( + MemoryConfigService, + ) + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=int(config_id), + service_name="EmotionAnalyticsService.generate_emotion_suggestions" + ) + from app.core.memory.client_factory import MemoryClientFactory + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(str(memory_config.llm_model_id)) + except Exception as e: + logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}") # 2. 获取情绪健康数据 health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d") @@ -498,8 +517,9 @@ class EmotionAnalyticsService: prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile) # 7. 
调用LLM生成建议(使用配置中的LLM) - from app.core.memory.utils.llm.llm_utils import get_llm_client - llm_client = get_llm_client() + if llm_client is None: + # 无法获取配置时,抛出错误而不是使用默认配置 + raise ValueError("无法获取LLM配置,请确保end_user关联了有效的memory_config") # 将 prompt 转换为 messages 格式 messages = [ @@ -598,7 +618,9 @@ class EmotionAnalyticsService: Returns: str: LLM prompt """ - from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt + from app.core.memory.utils.prompt.prompt_utils import ( + render_emotion_suggestions_prompt, + ) prompt = await render_emotion_suggestions_prompt( health_data=health_data, diff --git a/api/app/services/emotion_extraction_service.py b/api/app/services/emotion_extraction_service.py index b3172df1..d134251d 100644 --- a/api/app/services/emotion_extraction_service.py +++ b/api/app/services/emotion_extraction_service.py @@ -9,10 +9,12 @@ Classes: import logging from typing import Optional -from app.core.memory.models.emotion_models import EmotionExtraction -from app.models.data_config_model import DataConfig -from app.core.memory.utils.llm.llm_utils import get_llm_client + from app.core.memory.llm_tools.llm_client import LLMClientException +from app.core.memory.models.emotion_models import EmotionExtraction +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context +from app.models.data_config_model import DataConfig logger = logging.getLogger(__name__) @@ -50,7 +52,9 @@ class EmotionExtractionService: """ if self.llm_client is None or model_id: effective_model_id = model_id or self.llm_id - self.llm_client = get_llm_client(effective_model_id) + with get_db_context() as db: + factory = MemoryClientFactory(db) + self.llm_client = factory.get_llm_client(effective_model_id) return self.llm_client async def extract_emotion( @@ -142,7 +146,9 @@ class EmotionExtractionService: Returns: Formatted prompt string for LLM """ - from app.core.memory.utils.prompt.prompt_utils import 
render_emotion_extraction_prompt + from app.core.memory.utils.prompt.prompt_utils import ( + render_emotion_extraction_prompt, + ) prompt = await render_emotion_extraction_prompt( statement=statement, diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 27fdfa48..e23f9471 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -21,8 +21,8 @@ from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.db import get_db +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig @@ -45,8 +45,7 @@ config_logger = get_config_logger() # Initialize Neo4j connector for analytics functions _neo4j_connector = Neo4jConnector() -db_gen = get_db() -db = next(db_gen) + class MemoryAgentService: """Service for memory agent operations""" @@ -55,27 +54,6 @@ class MemoryAgentService: self.user_locks: Dict[str, Lock] = {} self.locks_lock = Lock() - def load_memory_config(self, config_id: int) -> MemoryConfig: - """ - Load memory configuration from database by config_id. - - This method delegates to the centralized MemoryConfigService to avoid - code duplication with other services. 
- - Args: - config_id: Configuration ID from database - - Returns: - MemoryConfig: Immutable configuration object - - Raises: - ConfigurationError: If validation fails - """ - return MemoryConfigService.load_memory_config( - config_id=config_id, - service_name="MemoryAgentService" - ) - def writer_messages_deal(self,messages,start_time,group_id,config_id,message): messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') countext = re.findall(r'"status": "(.*?)",', messages)[0] @@ -277,14 +255,17 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str: + async def write_memory(self, group_id: str, message: str, config_id: str, db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id Args: - group_id: Group identifier + group_id: Group identifier (also used as end_user_id) message: Message to write config_id: Configuration ID from database + db: SQLAlchemy database session + storage_type: Storage type (neo4j or rag) + user_rag_memory_id: User RAG memory ID Returns: Write operation result status @@ -292,14 +273,24 @@ class MemoryAgentService: Raises: ValueError: If config loading fails or write operation fails """ - if config_id==None: - config_id = os.getenv("config_id") + # Resolve config_id if None using end_user's connected config + if config_id is None: + try: + connected_config = get_end_user_connected_config(group_id, db) + config_id = connected_config.get("memory_config_id") + except Exception as e: + logger.warning(f"Failed to get connected config for end_user {group_id}: {e}") + import time start_time = time.time() # Load configuration from database only try: - memory_config = self.load_memory_config(config_id) + config_service = 
MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="MemoryAgentService" + ) logger.info(f"Configuration loaded successfully: {memory_config.config_name}") except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" @@ -366,6 +357,7 @@ class MemoryAgentService: history: List[Dict], search_switch: str, config_id: str, + db: Session, storage_type: str, user_rag_memory_id: str ) -> Dict: @@ -378,11 +370,14 @@ class MemoryAgentService: - "2": Direct answer based on context Args: - group_id: Group identifier + group_id: Group identifier (also used as end_user_id) message: User message history: Conversation history search_switch: Search mode switch config_id: Configuration ID from database + db: SQLAlchemy database session + storage_type: Storage type (neo4j or rag) + user_rag_memory_id: User RAG memory ID Returns: Dict with 'answer' and 'intermediate_outputs' keys @@ -394,8 +389,13 @@ class MemoryAgentService: import time start_time = time.time() - if config_id==None: - config_id = os.getenv("config_id") + # Resolve config_id if None using end_user's connected config + if config_id is None: + try: + connected_config = get_end_user_connected_config(group_id, db) + config_id = connected_config.get("memory_config_id") + except Exception as e: + logger.warning(f"Failed to get connected config for end_user {group_id}: {e}") logger.info(f"Read operation for group {group_id} with config_id {config_id}") @@ -411,7 +411,11 @@ class MemoryAgentService: with group_lock: # Step 1: Load configuration from database only try: - memory_config = self.load_memory_config(config_id) + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="MemoryAgentService" + ) logger.info(f"Configuration loaded successfully: {memory_config.config_name}") except ConfigurationError as e: error_msg = f"Failed to load 
configuration for config_id: {config_id}: {e}" @@ -696,7 +700,11 @@ class MemoryAgentService: logger.info("Classifying message type") # Load configuration to get LLM model ID - memory_config = self.load_memory_config(config_id) + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="MemoryAgentService" + ) status = await status_typle(message, memory_config.llm_model_id) logger.debug(f"Message type: {status}") @@ -865,7 +873,8 @@ class MemoryAgentService: self, end_user_id: Optional[str] = None, current_user_id: Optional[str] = None, - llm_id: Optional[str] = None + llm_id: Optional[str] = None, + db: Session = None ) -> Dict[str, Any]: """ 获取用户详情,包含: @@ -877,6 +886,7 @@ class MemoryAgentService: - end_user_id: 用户ID(可选) - current_user_id: 当前登录用户的ID(保留参数) - llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成) + - db: 数据库会话(可选) 返回格式: { @@ -893,7 +903,7 @@ class MemoryAgentService: # 1. 根据 end_user_id 获取 end_user_name try: - if end_user_id: + if end_user_id and db: from app.repositories import end_user_repository from app.schemas.end_user_schema import EndUser as EndUserSchema @@ -948,7 +958,9 @@ class MemoryAgentService: logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities") # 使用LLM提取标签 - llm_client = get_llm_client(llm_id) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(llm_id) # 定义标签提取的结构 class UserTags(BaseModel): @@ -1110,4 +1122,69 @@ class MemoryAgentService: # "msg": "解析失败", # "error_code": "DOC_PARSE_ERROR", # "data": {"error": str(e)} -# } \ No newline at end of file +# } + + +def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]: + """ + 获取终端用户关联的记忆配置 + + 通过以下流程获取配置: + 1. 根据 end_user_id 获取用户的 app_id + 2. 获取该应用的最新发布版本 + 3. 
从发布版本的 config 字段中提取 memory_config_id + + Args: + end_user_id: 终端用户ID + db: 数据库会话 + + Returns: + 包含 memory_config_id 和相关信息的字典 + + Raises: + ValueError: 当终端用户不存在或应用未发布时 + """ + from app.models.app_release_model import AppRelease + from app.models.end_user_model import EndUser + from sqlalchemy import select + + logger.info(f"Getting connected config for end_user: {end_user_id}") + + # 1. 获取 end_user 及其 app_id + end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() + if not end_user: + logger.warning(f"End user not found: {end_user_id}") + raise ValueError(f"终端用户不存在: {end_user_id}") + + app_id = end_user.app_id + logger.debug(f"Found end_user app_id: {app_id}") + + # 2. 获取该应用的最新发布版本 + stmt = ( + select(AppRelease) + .where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True)) + .order_by(AppRelease.version.desc()) + ) + latest_release = db.scalars(stmt).first() + + if not latest_release: + logger.warning(f"No active release found for app: {app_id}") + raise ValueError(f"应用未发布: {app_id}") + + logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}") + + # 3. 
从 config 中提取 memory_config_id + config = latest_release.config or {} + memory_obj = config.get('memory', {}) + memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None + + result = { + "end_user_id": str(end_user_id), + "app_id": str(app_id), + "release_id": str(latest_release.id), + "release_version": latest_release.version, + "memory_config_id": memory_config_id + } + + logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}") + return result \ No newline at end of file diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index d19eb02a..3413ebd6 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -3,7 +3,6 @@ Memory Configuration Service Centralized configuration loading and management for memory services. This service eliminates code duplication between MemoryAgentService and MemoryStorageService. -Database session management is handled internally. """ import time @@ -57,7 +56,7 @@ def _validate_config_id(config_id): invalid_value=config_id, ) return parsed_id - except ValueError as e: + except ValueError: raise InvalidConfigError( f"Invalid configuration ID format: '{config_id}'", field_name="config_id", @@ -77,19 +76,29 @@ class MemoryConfigService: This class provides a single implementation of configuration loading logic that can be shared across multiple services, eliminating code duplication. - Database session management is handled internally. + + Usage: + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config(config_id) + model_config = config_service.get_model_config(model_id) """ - @staticmethod + def __init__(self, db: Session): + """Initialize the service with a database session. 
+ + Args: + db: SQLAlchemy database session + """ + self.db = db + def load_memory_config( + self, config_id: int, service_name: str = "MemoryConfigService", ) -> MemoryConfig: """ Load memory configuration from database by config_id. - This method manages its own database session internally. - Args: config_id: Configuration ID from database service_name: Name of the calling service (for logging purposes) @@ -100,27 +109,6 @@ class MemoryConfigService: Raises: ConfigurationError: If validation fails """ - from app.db import get_db - - db_gen = get_db() - db = next(db_gen) - - try: - return MemoryConfigService._load_memory_config_with_db( - config_id=config_id, - db=db, - service_name=service_name, - ) - finally: - db.close() - - @staticmethod - def _load_memory_config_with_db( - config_id: int, - db: Session, - service_name: str = "MemoryConfigService", - ) -> MemoryConfig: - """Internal method that loads memory configuration with an existing db session.""" start_time = time.time() config_logger.info( @@ -137,7 +125,7 @@ class MemoryConfigService: try: validated_config_id = _validate_config_id(config_id) - result = DataConfigRepository.get_config_with_workspace(db, validated_config_id) + result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id) if not result: elapsed_ms = (time.time() - start_time) * 1000 config_logger.error( @@ -160,7 +148,7 @@ class MemoryConfigService: embedding_uuid = validate_embedding_model( validated_config_id, memory_config.embedding_id, - db, + self.db, workspace.tenant_id, workspace.id, ) @@ -169,7 +157,7 @@ class MemoryConfigService: llm_uuid, llm_name = validate_and_resolve_model_id( memory_config.llm_id, "llm", - db, + self.db, workspace.tenant_id, required=True, config_id=validated_config_id, @@ -183,7 +171,7 @@ class MemoryConfigService: rerank_uuid, rerank_name = validate_and_resolve_model_id( memory_config.rerank_id, "rerank", - db, + self.db, workspace.tenant_id, required=False, 
config_id=validated_config_id, @@ -194,7 +182,7 @@ class MemoryConfigService: embedding_name, _ = validate_model_exists_and_active( embedding_uuid, "embedding", - db, + self.db, workspace.tenant_id, config_id=validated_config_id, workspace_id=workspace.id, @@ -220,6 +208,25 @@ class MemoryConfigService: reflexion_range=memory_config.reflexion_range or "retrieval", reflexion_baseline=memory_config.baseline or "time", loaded_at=datetime.now(), + # Pipeline config: Deduplication + enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, + enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, + deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True, + t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8, + t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8, + t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8, + # Pipeline config: Statement extraction + statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, + include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, + max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000, + # Pipeline config: Forgetting engine + lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5, + lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5, + offset=float(memory_config.offset) if memory_config.offset is not None else 0.0, + # Pipeline config: Pruning + pruning_enabled=bool(memory_config.pruning_enabled) if 
def _load_model_settings(
    self,
    model_id: str,
    *,
    log_label: str,
    not_found_detail: str,
    timeout: float,
    max_retries: int,
) -> dict:
    """Look up a model and flatten its first API key into a settings dict.

    Shared implementation behind ``get_model_config`` and
    ``get_embedder_config``; the two differ only in their log/error wording
    and in the timeout / retry values they report.

    Args:
        model_id: ID passed to ``ModelConfigService.get_model_by_id``.
        log_label: Prefix used in the warning log lines.
        not_found_detail: ``detail`` text of the 404 raised when the lookup
            returns nothing.
        timeout: Value stored under the ``"timeout"`` key.
        max_retries: Value stored under the ``"max_retries"`` key.

    Returns:
        Dict with the keys ``model_name``, ``provider``, ``api_key``,
        ``base_url``, ``model_config_id``, ``type``, ``timeout`` and
        ``max_retries``.

    Raises:
        HTTPException: 404 when the model does not exist, or when it exists
            but has no API key record attached.
    """
    from app.models.models_model import ModelApiKey
    from app.services.model_service import ModelConfigService as ModelSvc
    from fastapi import status
    from fastapi.exceptions import HTTPException

    config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
    if not config:
        logger.warning(f"{log_label} {model_id} not found")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=not_found_detail)

    # Indexing an empty ``api_keys`` collection used to surface as a bare
    # IndexError; report it the same way as a missing model instead.
    if not config.api_keys:
        logger.warning(f"{log_label} {model_id} has no API key configured")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型未配置API密钥")

    # Only the first API key record is used.
    api_config: ModelApiKey = config.api_keys[0]

    return {
        "model_name": api_config.model_name,
        "provider": api_config.provider,
        "api_key": api_config.api_key,
        "base_url": api_config.api_base,
        "model_config_id": api_config.model_config_id,
        "type": config.type,
        "timeout": timeout,
        "max_retries": max_retries,
    }


def get_model_config(self, model_id: str) -> dict:
    """Get LLM model configuration by ID.

    Args:
        model_id: Model ID to look up.

    Returns:
        Dict with model configuration (``model_name``, ``provider``,
        ``api_key``, ``base_url``, ``model_config_id``, ``type``) plus
        ``timeout`` / ``max_retries`` taken from ``settings.LLM_TIMEOUT``
        and ``settings.LLM_MAX_RETRIES``.

    Raises:
        HTTPException: 404 when the model is unknown or has no API key.
    """
    from app.core.config import settings

    return self._load_model_settings(
        model_id,
        log_label="Model ID",
        not_found_detail="模型ID不存在",
        timeout=settings.LLM_TIMEOUT,
        max_retries=settings.LLM_MAX_RETRIES,
    )


def get_embedder_config(self, embedding_id: str) -> dict:
    """Get embedding model configuration by ID.

    Args:
        embedding_id: Embedding model ID to look up.

    Returns:
        Dict with the same keys as ``get_model_config``; ``timeout`` is fixed
        at 120.0 and ``max_retries`` at 5 for embedders.

    Raises:
        HTTPException: 404 when the model is unknown or has no API key.
    """
    return self._load_model_settings(
        embedding_id,
        log_label="Embedding model ID",
        not_found_detail="嵌入模型ID不存在",
        timeout=120.0,
        max_retries=5,
    )


@staticmethod
def get_pipeline_config(memory_config: MemoryConfig) -> "ExtractionPipelineConfig":
    """Build ExtractionPipelineConfig from MemoryConfig.

    Args:
        memory_config: MemoryConfig object containing all pipeline settings.

    Returns:
        ExtractionPipelineConfig assembled from the deduplication, statement
        extraction and forgetting engine fields of ``memory_config``. Pruning
        settings are not included here; see ``get_pruning_config``.
    """
    from app.core.memory.models.variate_config import (
        DedupConfig,
        ExtractionPipelineConfig,
        ForgettingEngineConfig,
        StatementExtractionConfig,
    )

    # MemoryConfig's t_* thresholds map onto DedupConfig's fuzzy_* names.
    dedup_config = DedupConfig(
        enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
        enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
        fuzzy_name_threshold_strict=memory_config.t_name_strict,
        fuzzy_type_threshold_strict=memory_config.t_type_strict,
        fuzzy_overall_threshold=memory_config.t_overall,
    )

    stmt_config = StatementExtractionConfig(
        statement_granularity=memory_config.statement_granularity,
        include_dialogue_context=memory_config.include_dialogue_context,
        max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
    )

    forget_config = ForgettingEngineConfig(
        offset=memory_config.offset,
        lambda_time=memory_config.lambda_time,
        lambda_mem=memory_config.lambda_mem,
    )

    return ExtractionPipelineConfig(
        statement_extraction=stmt_config,
        deduplication=dedup_config,
        forgetting_engine=forget_config,
    )


@staticmethod
def get_pruning_config(memory_config: MemoryConfig) -> dict:
    """Retrieve semantic pruning config from MemoryConfig.

    Args:
        memory_config: MemoryConfig object containing pruning settings.

    Returns:
        Dict suitable for PruningConfig.model_validate with keys:
            - pruning_switch: value of ``memory_config.pruning_enabled``
            - pruning_scene: value of ``memory_config.pruning_scene``
            - pruning_threshold: value of ``memory_config.pruning_threshold``
    """
    return {
        "pruning_switch": memory_config.pruning_enabled,
        "pruning_scene": memory_config.pruning_scene,
        "pruning_threshold": memory_config.pruning_threshold,
    }
# One dialogue turn: a speaker label ("用户" or "AI"), a colon (ASCII or
# full-width), then the message text. The message runs on across following
# lines until a line that itself starts with a speaker label. The continuation
# group is greedy on purpose: a lazy quantifier at the very end of a pattern
# always matches zero times, which silently dropped every continuation line.
_DIALOGUE_TURN_PATTERN = re.compile(
    r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*)"
)


def split_dialogue_turns(dialogue_text: str) -> "list[tuple[str, str]]":
    """Split raw dialogue text into ``(role, content)`` pairs.

    Args:
        dialogue_text: Text in which turns are introduced by "用户:" or "AI:".

    Returns:
        One ``(role, stripped_content)`` tuple per non-empty turn, in order of
        appearance. When no labelled turn is found, the whole stripped text is
        returned as a single "用户" turn.
    """
    turns = [
        (role, content.strip())
        for role, content in _DIALOGUE_TURN_PATTERN.findall(dialogue_text)
        if content.strip()
    ]
    if not turns:
        # Unlabelled input: treat the entire text as one user message.
        turns = [("用户", dialogue_text.strip())]
    return turns


def _append_log_line(log_file: str, text: str) -> None:
    """Append ``text`` verbatim to the timing log file."""
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(text)


async def run_pilot_extraction(
    memory_config: MemoryConfig,
    dialogue_text: str,
    db: Session,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
) -> None:
    """Run the knowledge-extraction pipeline in pilot (trial) mode.

    The dialogue is parsed, chunked, run through the extraction orchestrator
    with ``is_pilot_run=True`` and summarised; this function itself never
    writes the results to Neo4j. Step timings are appended to
    ``logs/time.log``.

    Args:
        memory_config: Memory configuration loaded from the database.
        dialogue_text: Input dialogue text.
        db: SQLAlchemy session handed to ``MemoryClientFactory`` to build the
            LLM and embedder clients.
        progress_callback: Optional async progress callback, awaited as
            ``callback(stage, message, data)`` where
            - stage: identifier of the current processing stage
            - message: human-readable progress message
            - data: optional dict with extra details (``None`` when absent)

    Raises:
        Exception: Any failure outside the summary step is logged and
            re-raised. A failure inside the summary step is logged only.
    """
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    _append_log_line(log_file, f"\n=== Pilot Run Started: {timestamp} ===\n")

    pipeline_start = time.time()
    neo4j_connector = None

    try:
        # Step 1: initialise clients
        logger.info("Initializing clients...")
        step_start = time.time()

        client_factory = MemoryClientFactory(db)
        llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
        embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))

        neo4j_connector = Neo4jConnector()

        log_time("Client Initialization", time.time() - step_start, log_file)

        # Step 2: parse the dialogue text into messages and chunk it
        logger.info("Parsing dialogue text...")
        step_start = time.time()

        messages = [
            ConversationMessage(role=role, msg=content)
            for role, content in split_dialogue_turns(dialogue_text)
        ]

        context = ConversationContext(msgs=messages)
        dialog = DialogData(
            context=context,
            ref_id="pilot_dialog_1",
            group_id=str(memory_config.workspace_id),
            user_id=str(memory_config.tenant_id),
            apply_id=str(memory_config.config_id),
            metadata={"source": "pilot_run", "input_type": "frontend_text"},
        )

        if progress_callback:
            # The callback type declares three arguments; pass the third
            # explicitly so callbacks without a default for it still work.
            await progress_callback("text_preprocessing", "开始预处理文本...", None)

        chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
            data=[dialog],
            chunker_strategy=memory_config.chunker_strategy,
            llm_client=llm_client,
        )
        logger.info(f"Processed dialogue text: {len(messages)} messages")

        # Progress callback: report every chunk, then a summary
        if progress_callback:
            for dlg in chunked_dialogs:
                for i, chunk in enumerate(dlg.chunks):
                    # Preview is capped at 200 characters plus an ellipsis.
                    preview = (
                        chunk.content[:200] + "..."
                        if len(chunk.content) > 200
                        else chunk.content
                    )
                    chunk_result = {
                        "chunk_index": i + 1,
                        "content": preview,
                        "full_length": len(chunk.content),
                        "dialog_id": dlg.id,
                        "chunker_strategy": memory_config.chunker_strategy,
                    }
                    await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)

            preprocessing_summary = {
                "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
                "total_dialogs": len(chunked_dialogs),
                "chunker_strategy": memory_config.chunker_strategy,
            }
            await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)

        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 3: initialise the pipeline orchestrator
        logger.info("Initializing extraction orchestrator...")
        step_start = time.time()

        config = get_pipeline_config(memory_config)
        logger.info(
            f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
            f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
        )

        orchestrator = ExtractionOrchestrator(
            llm_client=llm_client,
            embedder_client=embedder_client,
            connector=neo4j_connector,
            config=config,
            progress_callback=progress_callback,
            embedding_id=str(memory_config.embedding_model_id),
        )

        log_time("Orchestrator Initialization", time.time() - step_start, log_file)

        # Step 4: run the knowledge-extraction pipeline
        logger.info("Running extraction pipeline...")
        step_start = time.time()

        if progress_callback:
            await progress_callback("knowledge_extraction", "正在知识抽取...", None)

        # The returned node/edge collections are not used in pilot mode, so
        # the result is intentionally discarded.
        await orchestrator.run(
            dialog_data_list=chunked_dialogs,
            is_pilot_run=True,
        )

        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        if progress_callback:
            await progress_callback("generating_results", "正在生成结果...", None)

        # Step 5: generate memory summaries (best effort; a failure here is
        # logged and does not fail the pilot run)
        try:
            logger.info("Generating memory summaries...")
            step_start = time.time()

            from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
                memory_summary_generation,
            )

            await memory_summary_generation(
                chunked_dialogs,
                llm_client=llm_client,
                embedder_client=embedder_client,
            )

            log_time("Memory Summary Generation", time.time() - step_start, log_file)
        except Exception as e:
            logger.error(f"Memory summary step failed: {e}", exc_info=True)

        logger.info("Pilot run completed: Skipping Neo4j save")

    except Exception as e:
        logger.error(f"Pilot run failed: {e}", exc_info=True)
        raise
    finally:
        if neo4j_connector:
            try:
                await neo4j_connector.close()
            except Exception as close_error:
                # Closing is best effort, but a failure should be visible.
                logger.warning(f"Failed to close Neo4j connector: {close_error}")

        total_time = time.time() - pipeline_start
        log_time("TOTAL PILOT RUN TIME", total_time, log_file)

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        _append_log_line(log_file, f"=== Pilot Run Completed: {timestamp} ===\n\n")

        logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")
Args: - group_id: Group ID for the memory agent + group_id: Group ID for the memory agent (also used as end_user_id) message: User message to process history: Conversation history search_switch: Search switch parameter @@ -190,9 +190,28 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, """ start_time = time.time() + # Resolve config_id if None + actual_config_id = config_id + if actual_config_id is None: + try: + from app.services.memory_agent_service import get_end_user_connected_config + db = next(get_db()) + try: + connected_config = get_end_user_connected_config(group_id, db) + actual_config_id = connected_config.get("memory_config_id") + finally: + db.close() + except Exception as e: + # Log but continue - will fail later with proper error + pass + async def _run() -> str: - service = MemoryAgentService() - return await service.read_memory(group_id, message, history, search_switch, config_id,storage_type,user_rag_memory_id) + db = next(get_db()) + try: + service = MemoryAgentService() + return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id) + finally: + db.close() try: # 使用 nest_asyncio 来避免事件循环冲突 @@ -246,7 +265,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage """Celery task to process a write message via MemoryAgentService. 
Args: - group_id: Group ID for the memory agent + group_id: Group ID for the memory agent (also used as end_user_id) message: Message to write config_id: Optional configuration ID @@ -258,9 +277,28 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage """ start_time = time.time() + # Resolve config_id if None + actual_config_id = config_id + if actual_config_id is None: + try: + from app.services.memory_agent_service import get_end_user_connected_config + db = next(get_db()) + try: + connected_config = get_end_user_connected_config(group_id, db) + actual_config_id = connected_config.get("memory_config_id") + finally: + db.close() + except Exception as e: + # Log but continue - will fail later with proper error + pass + async def _run() -> str: - service = MemoryAgentService() - return await service.write_memory(group_id, message, config_id,storage_type,user_rag_memory_id) + db = next(get_db()) + try: + service = MemoryAgentService() + return await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id) + finally: + db.close() try: # 使用 nest_asyncio 来避免事件循环冲突