refactor(memory): restructure memory system and improve configuration management

- Remove deprecated main.py entry point from memory module
- Reorganize imports across controllers and services for consistency
- Update emotion controller to pass db session instead of config_id to services
- Enhance memory agent controller with db session parameter for status_type and user_profile endpoints
- Refactor memory agent service to accept db parameter in classify_message_type method
- Improve configuration handling in celery_app by removing automatic database reload
- Update all memory-related services to use centralized config management
- Standardize import ordering and remove unused imports across 50+ files
- Add pilot_run_service for new pilot execution workflow
- Refactor extraction engine, reflection engine, and search services for better modularity
- Update LLM utilities and embedder configuration for improved flexibility
- Enhance type classifier and verification tools with better error handling
- Improve memory evaluation modules (LOCOMO, LongMemEval, MemSciQA) with consistent patterns
This commit is contained in:
Ke Sun
2025-12-23 17:17:04 +08:00
parent 258b88276f
commit 283c64a358
58 changed files with 2171 additions and 1797 deletions

View File

@@ -1,9 +1,9 @@
import os
from datetime import timedelta
from urllib.parse import quote
from celery import Celery
from app.core.config import settings
from app.core.memory.utils.config.definitions import reload_configuration_from_database
from celery import Celery
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0)
@@ -13,7 +13,6 @@ celery_app = Celery(
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
)
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
# 配置使用本地队列,避免与远程 worker 冲突
celery_app.conf.task_default_queue = 'localhost_test_wyl'
@@ -22,6 +21,7 @@ celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# macOS 兼容性配置
import platform
if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')

View File

@@ -10,22 +10,21 @@ Routes:
POST /emotion/suggestions - 获取个性化情绪建议
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.dependencies import get_current_user, get_db
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.schemas.emotion_schema import (
EmotionHealthRequest,
EmotionSuggestionsRequest,
EmotionTagsRequest,
EmotionWordcloudRequest,
EmotionHealthRequest,
EmotionSuggestionsRequest
)
from app.schemas.response_schema import ApiResponse
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.core.logging_config import get_api_logger
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
# 获取API专用日志器
api_logger = get_api_logger()
@@ -230,7 +229,7 @@ async def get_emotion_suggestions(
# 调用服务层
data = await emotion_service.generate_emotion_suggestions(
end_user_id=request.group_id,
config_id=config_id
db=db
)
api_logger.info(

View File

@@ -163,7 +163,8 @@ async def write_server(
result = await memory_agent_service.write_memory(
user_input.group_id,
user_input.message,
config_id,
config_id,
db,
storage_type,
user_rag_memory_id
)
@@ -280,6 +281,7 @@ async def read_server(
user_input.history,
user_input.search_switch,
config_id,
db,
storage_type,
user_rag_memory_id
)
@@ -548,6 +550,7 @@ async def get_write_task_result(
@router.post("/status_type", response_model=ApiResponse)
async def status_type(
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
@@ -561,7 +564,11 @@ async def status_type(
"""
api_logger.info(f"Status type check requested for group {user_input.group_id}")
try:
result = await memory_agent_service.classify_message_type(user_input.message)
result = await memory_agent_service.classify_message_type(
user_input.message,
user_input.config_id,
db
)
return success(data=result)
except Exception as e:
api_logger.error(f"Message type classification failed: {str(e)}")
@@ -636,6 +643,7 @@ async def get_hot_memory_tags_by_user_api(
@router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
@@ -659,7 +667,8 @@ async def get_user_profile_api(
try:
result = await memory_agent_service.get_user_profile(
end_user_id=end_user_id,
current_user_id=str(current_user.id)
current_user_id=str(current_user.id),
db=db
)
return success(data=result, msg="获取用户详情成功")
except Exception as e:
@@ -694,4 +703,41 @@ async def get_user_profile_api(
# )
# except Exception as e:
# api_logger.error(f"API docs retrieval failed: {str(e)}")
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
async def get_end_user_connected_config(
    end_user_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return the memory configuration connected to an end user.

    The lookup is delegated to the service-layer function of the same name.
    As described for this endpoint, that lookup is expected to:
        1. Resolve the end user's app_id from end_user_id.
        2. Fetch the latest published release of that application.
        3. Read memory_config_id from the release's config field.

    Args:
        end_user_id: ID of the end user, taken from the URL path.
        db: Database session injected through get_db.
        current_user: Injected through get_current_user; not read in the
            body (presumably present only to require authentication --
            confirm against get_current_user).

    Returns:
        A success response wrapping the service result (described as
        containing memory_config_id and related information). If the
        service raises ValueError the response is a failure with
        BizCode.NOT_FOUND; any other exception yields a failure with
        BizCode.INTERNAL_ERROR.
    """
    # Imported under an alias because the service function has the same
    # name as this endpoint function.
    from app.services.memory_agent_service import (
        get_end_user_connected_config as lookup_connected_config,
    )

    api_logger.info(f"Getting connected config for end_user: {end_user_id}")
    try:
        connected_config = lookup_connected_config(end_user_id, db)
        return success(data=connected_config, msg="获取终端用户关联配置成功")
    except ValueError as not_found:
        # The service signals "no configuration for this user" via ValueError.
        reason = str(not_found)
        api_logger.warning(f"End user config not found: {reason}")
        return fail(BizCode.NOT_FOUND, reason)
    except Exception as exc:
        # Any other failure is logged with its traceback and reported as an
        # internal error.
        detail = str(exc)
        api_logger.error(f"Failed to get end user connected config: {detail}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", detail)

View File

@@ -1,22 +1,27 @@
import asyncio
import time
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from sqlalchemy import text
from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
ReflectionConfig,
ReflectionEngine,
)
from app.core.response_utils import success
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine
from app.dependencies import get_current_user
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
from app.schemas.memory_reflection_schemas import Memory_Reflection
from app.services.memory_reflection_service import (
MemoryReflectionService,
WorkspaceAppService,
)
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import text
from sqlalchemy.orm import Session
load_dotenv()
api_logger = get_api_logger()

View File

@@ -9,18 +9,19 @@ LangChain Agent 封装
"""
import os
import time
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import BaseTool
from langchain.agents import create_agent
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.core.logging_config import get_business_logger
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
logger = get_business_logger()
@@ -198,10 +199,24 @@ class LangChainAgent:
"""
message_chat= message
start_time = time.time()
if config_id == None:
actual_config_id = os.getenv("config_id")
else:
actual_config_id = config_id
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
@@ -295,10 +310,24 @@ class LangChainAgent:
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80)
message_chat = message
if config_id == None:
actual_config_id = os.getenv("config_id")
else:
actual_config_id = config_id
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
history_term_memory = await self.term_memory_redis_read(end_user_id)
if memory_flag:

View File

@@ -1,7 +1,8 @@
import os
import json
import os
from pathlib import Path
from typing import Dict, Any, Optional
from typing import Any, Dict, Optional
from dotenv import load_dotenv
load_dotenv()
@@ -81,6 +82,7 @@ class Settings:
VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query")
# Langfuse configuration
LANGFUSE_ENABLED: bool = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
@@ -153,9 +155,6 @@ class Settings:
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
MEMORY_CONFIG_FILE: str = os.getenv("MEMORY_CONFIG_FILE", "config.json")
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
# Tool Management Configuration
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
@@ -178,65 +177,6 @@ class Settings:
return str(base_path / filename)
return str(base_path)
def get_memory_config_path(self, config_file: str = "") -> str:
"""
Get the full path for memory module configuration files.
Args:
config_file: Optional config filename (defaults to MEMORY_CONFIG_FILE)
Returns:
Full path to the config file
"""
if not config_file:
config_file = self.MEMORY_CONFIG_FILE
return str(Path(self.MEMORY_CONFIG_DIR) / config_file)
def load_memory_config(self) -> Dict[str, Any]:
"""
Load memory module configuration from config.json.
Returns:
Dictionary containing memory configuration
"""
config_path = self.get_memory_config_path(self.MEMORY_CONFIG_FILE)
try:
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory config file not found or malformed at {config_path}. Error: {e}")
return {}
def load_memory_runtime_config(self) -> Dict[str, Any]:
"""
Load memory module runtime configuration from runtime.json.
Returns:
Dictionary containing runtime configuration
"""
runtime_path = self.get_memory_config_path(self.MEMORY_RUNTIME_FILE)
try:
with open(runtime_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory runtime config not found or malformed at {runtime_path}. Error: {e}")
return {"selections": {}}
def load_memory_dbrun_config(self) -> Dict[str, Any]:
"""
Load memory module database run configuration from dbrun.json.
Returns:
Dictionary containing dbrun configuration
"""
dbrun_path = self.get_memory_config_path(self.MEMORY_DBRUN_FILE)
try:
with open(dbrun_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory dbrun config not found or malformed at {dbrun_path}. Error: {e}")
return {"selections": {}}
def ensure_memory_output_dir(self) -> None:
"""
Ensure the memory output directory exists.

View File

@@ -141,7 +141,7 @@ class SearchService:
cleaned_query = self.clean_query(question)
try:
# Execute search using embedding_model_id from memory_config
# Execute search using memory_config
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
@@ -149,7 +149,7 @@ class SearchService:
limit=limit,
include=include,
output_path=output_path,
embedding_id=str(config.embedding_model_id),
memory_config=config,
rerank_alpha=rerank_alpha,
)

View File

@@ -13,7 +13,8 @@ from app.core.memory.agent.mcp_server.models.retrieval_models import (
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.write_tools import write
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
@@ -41,8 +42,10 @@ async def Data_type_differentiation(
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
# Get LLM client from memory_config using factory pattern
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Render template
try:

View File

@@ -16,7 +16,8 @@ from app.core.memory.agent.mcp_server.models.problem_models import (
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
@@ -56,7 +57,9 @@ async def Split_The_Problem(
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Extract user ID from session
user_id = session_service.resolve_user_id(sessionid)
@@ -190,7 +193,9 @@ async def Problem_Extension(
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID from usermessages
from app.core.memory.agent.utils.messages_tool import Resolve_username

View File

@@ -21,8 +21,9 @@ from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Summary_messages_deal,
)
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from mcp.server.fastmcp import Context
@@ -66,7 +67,9 @@ async def Summary(
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
@@ -210,7 +213,9 @@ async def Retrieve_Summary(
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
@@ -425,7 +430,9 @@ async def Input_Summary(
search_service = get_context_resource(ctx, "search_service")
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""

View File

@@ -1,15 +1,14 @@
"""
Type classification utility for distinguishing read/write operations.
"""
from jinja2 import Template
from pydantic import BaseModel
from app.core.config import settings
from app.core.logging_config import get_agent_logger, log_prompt_rendering
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.config import settings
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from jinja2 import Template
from pydantic import BaseModel
logger = get_agent_logger(__name__)
@@ -44,7 +43,9 @@ async def status_typle(messages: str, llm_model_id: str) -> dict:
"message": f"Prompt rendering failed: {str(e)}"
}
llm_client = get_llm_client(llm_model_id)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_model_id)
try:
structured = await llm_client.response_structured(

View File

@@ -1,18 +1,19 @@
from typing import TypedDict, Annotated, List, Any
from langchain_core.messages import AnyMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph, add_messages
import asyncio
import json
from dotenv import load_dotenv, find_dotenv
import os
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from langchain_core.messages import HumanMessage
from jinja2 import Environment, FileSystemLoader
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
from app.core.memory.utils.llm.llm_utils import get_llm_client
from typing import Annotated, Any, List, TypedDict
# Removed global variable imports - use dependency injection instead
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from dotenv import find_dotenv, load_dotenv
from jinja2 import Environment, FileSystemLoader
from langchain_core.messages import AnyMessage, HumanMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph, add_messages
load_dotenv(find_dotenv())
@@ -53,7 +54,9 @@ class VerifyTool:
async def model_1(self, state: State) -> State:
if not self.llm_model_id:
raise ValueError("llm_model_id is required but not provided")
llm_client = get_llm_client(self.llm_model_id)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(self.llm_model_id)
response_content = await llm_client.chat(
messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])]
)

View File

@@ -13,13 +13,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator
ExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
Memory_summary_generation,
memory_summary_generation,
)
from app.core.memory.utils.embedder.embedder_utils import (
get_embedder_client_from_config,
)
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
@@ -67,9 +65,11 @@ async def write(
logger.info(f"Chunker strategy: {chunker_strategy}")
logger.info(f"Group ID: {group_id}")
# Construct clients from memory_config
llm_client = get_llm_client_from_config(memory_config)
embedder_client = get_embedder_client_from_config(memory_config)
# Construct clients from memory_config using factory pattern with db session
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
embedder_client = factory.get_embedder_client_from_config(memory_config)
logger.info("LLM and embedding clients constructed")
# Initialize timing log
@@ -100,7 +100,7 @@ async def write(
# Step 2: Initialize and run ExtractionOrchestrator
step_start = time.time()
from app.core.memory.utils.config.config_utils import get_pipeline_config
pipeline_config = get_pipeline_config()
pipeline_config = get_pipeline_config(memory_config)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
@@ -155,8 +155,8 @@ async def write(
# Step 4: Generate Memory summaries and save to Neo4j
step_start = time.time()
try:
summaries = await Memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
summaries = await memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client
)
try:

View File

@@ -35,7 +35,9 @@ except NameError:
import json
from app.core.config import settings
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
# TODO: remove these hard-coded defaults once every user has a proper memory config
# Default values (previously from definitions.py)
@@ -47,11 +49,37 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
async def filter_tags_with_llm(tags: List[str], llm_client) -> List[str]:
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
"""
使用LLM筛选标签列表仅保留具有代表性的核心名词。
"""
try:
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
# 3. 构建Prompt
tag_list_str = ", ".join(tags)
@@ -156,8 +184,7 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
llm_client = get_llm_client(DEFAULT_LLM_ID)
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, llm_client)
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
# 3. 根据LLM的筛选结果构建最终的标签列表，保留原始频率和顺序
final_tags = []

View File

@@ -18,8 +18,10 @@ if src_path not in sys.path:
sys.path.insert(0, src_path)
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from pydantic import BaseModel, Field
#TODO: Fix this
@@ -59,7 +61,33 @@ class MemoryInsight:
def __init__(self, user_id: str):
self.user_id = user_id
self.neo4j_connector = Neo4jConnector()
self.llm_client = get_llm_client(DEFAULT_LLM_ID)
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self):
"""关闭数据库连接。"""

View File

@@ -25,8 +25,10 @@ except Exception:
pass
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
#TODO: Fix this
@@ -47,7 +49,33 @@ class UserSummary:
def __init__(self, user_id: str):
self.user_id = user_id
self.connector = Neo4jConnector()
self.llm = get_llm_client(DEFAULT_LLM_ID)
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self):
await self.connector.close()

View File

@@ -1,22 +1,34 @@
import os
import asyncio
import json
from typing import List, Dict, Any, Optional
from datetime import datetime
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_CHUNKER_STRATEGY, SELECTED_EMBEDDING_ID
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
from app.core.memory.utils.config.definitions import (
SELECTED_CHUNKER_STRATEGY,
SELECTED_EMBEDDING_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
# Import from database module
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# Cypher queries for evaluation
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
@@ -52,7 +64,9 @@ async def ingest_contexts_via_full_pipeline(
llm_available = True
try:
from app.core.memory.utils.config import definitions as config_defs
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
except Exception as e:
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
llm_available = False
@@ -133,12 +147,13 @@ async def ingest_contexts_via_full_pipeline(
return False
# 初始化 embedder 客户端
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.services.memory_config_service import MemoryConfigService
try:
embedder_config_dict = get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
with get_db_context() as db:
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
except Exception as e:
@@ -236,15 +251,15 @@ async def ingest_contexts_via_full_pipeline(
print("[Ingestion] Generating memory summaries...")
try:
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
Memory_summary_generation,
memory_summary_generation,
)
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
summaries = await Memory_summary_generation(
summaries = await memory_summary_generation(
chunked_dialogs=dialog_data_list,
llm_client=llm_client,
embedding_id=embedding_name or SELECTED_EMBEDDING_ID
embedder_client=embedder_client
)
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
except Exception as e:

View File

@@ -15,7 +15,7 @@ import json
import os
import time
from datetime import datetime
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
try:
from dotenv import load_dotenv
@@ -23,37 +23,38 @@ except ImportError:
def load_dotenv():
pass
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config_utils import get_embedder_config
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
SELECTED_EMBEDDING_ID
)
from app.core.memory.utils.llm_utils import get_llm_client
from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.evaluation.common.metrics import (
f1_score,
avg_context_tokens,
bleu1,
f1_score,
jaccard,
latency_stats,
avg_context_tokens
)
from app.core.memory.evaluation.locomo.locomo_metrics import (
get_category_name,
locomo_f1_score,
locomo_multi_f1,
get_category_name
)
from app.core.memory.evaluation.locomo.locomo_utils import (
load_locomo_data,
extract_conversations,
ingest_conversations_if_needed,
load_locomo_data,
resolve_temporal_references,
select_and_format_information,
retrieve_relevant_information,
ingest_conversations_if_needed
select_and_format_information,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
async def run_locomo_benchmark(
@@ -160,10 +161,16 @@ async def run_locomo_benchmark(
# Step 3: Initialize clients
print("🔧 Initializing clients...")
connector = Neo4jConnector()
llm_client = get_llm_client(SELECTED_LLM_ID)
# Initialize LLM client with database context
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Initialize embedder
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)

View File

@@ -1,14 +1,16 @@
# file name: check_neo4j_connection_fixed.py
import asyncio
import os
import sys
import json
import time
import math
import os
import re
import sys
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any
from typing import Any, Dict, List
from dotenv import load_dotenv
# 1
# 添加项目根目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -34,7 +36,7 @@ def _loc_normalize(text: str) -> str:
# 尝试从 metrics.py 导入基础指标
try:
from common.metrics import f1_score, bleu1, jaccard
from common.metrics import bleu1, f1_score, jaccard
print("✅ 从 metrics.py 导入基础指标成功")
except ImportError as e:
print(f"❌ 从 metrics.py 导入失败: {e}")
@@ -111,10 +113,14 @@ try:
# 尝试从不同位置导入
try:
from locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
from locomo.qwen_search_eval import (
_resolve_relative_times,
loc_f1_score,
loc_multi_f1,
)
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError:
from qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError as e:
@@ -429,13 +435,17 @@ async def run_enhanced_evaluation():
return None
# 修正导入路径:统一使用 app. 前缀的绝对导入(不再使用 app.core.memory.src 前缀)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 加载数据
# 获取项目根目录
@@ -458,10 +468,14 @@ async def run_enhanced_evaluation():
# 初始化增强监控器
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
llm = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)

View File

@@ -2,10 +2,11 @@ import argparse
import asyncio
import json
import os
import statistics
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any
import statistics
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
except Exception:
@@ -13,16 +14,31 @@ except Exception:
return None
import re
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
bleu1,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
@@ -327,9 +343,13 @@ async def run_locomo_eval(
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
# 使用异步LLM客户端
llm_client = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder用于直接调用
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)

View File

@@ -2,11 +2,11 @@ import argparse
import asyncio
import json
import os
import time
import re
import statistics
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
@@ -16,6 +16,7 @@ except Exception:
# 确保可以找到 src 及项目根路径
import sys
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR)))
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
@@ -25,19 +26,33 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
# 与现有评估脚本保持一致的导入方式
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
try:
# 优先从 extraction_utils 导入
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline # type: ignore
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline, # type: ignore
)
except Exception:
ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.services.memory_config_service import MemoryConfigService
try:
from app.core.memory.evaluation.common.metrics import exact_match
except Exception:
@@ -686,9 +701,13 @@ async def run_longmemeval_test(
)
# 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端
llm_client = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
connector = Neo4jConnector()
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
@@ -748,10 +767,10 @@ async def run_longmemeval_test(
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# for sm in summaries:
# summary_text = str(sm.get("summary", "")).strip()
# if summary_text:
# contexts_all.append(summary_text)
# 实体摘要最多3个
scored = [e for e in entities if e.get("score") is not None]

View File

@@ -2,11 +2,11 @@ import argparse
import asyncio
import json
import os
import time
import re
import statistics
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
@@ -15,15 +15,26 @@ except Exception:
return None
# 与现有评估脚本保持一致的导入方式
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config_utils import get_embedder_config
from app.core.memory.utils.llm_utils import get_llm_client
from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
try:
from app.core.memory.evaluation.common.metrics import exact_match
except Exception:
@@ -647,9 +658,13 @@ async def run_longmemeval_test(
items = qa_list[start_index:start_index + sample_size]
# 初始化组件 - 使用异步LLM客户端
llm_client = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
connector = Neo4jConnector()
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)

View File

@@ -4,19 +4,35 @@ import json
import os
import time
from datetime import datetime
from typing import List, Dict, Any
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
@@ -119,7 +135,7 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
return merged
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid") -> Dict[str, Any]:
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
group_id = group_id or SELECTED_GROUP_ID
# Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
@@ -134,7 +150,9 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
await ingest_contexts_via_full_pipeline(contexts, group_id)
# LLM client (使用异步调用)
llm_client = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Evaluate each item
connector = Neo4jConnector()
@@ -159,6 +177,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
limit=search_limit,
include=["dialogues", "statements", "entities"],
output_path=None,
memory_config=memory_config,
)
except Exception:
results = None
@@ -242,7 +261,11 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
correct_flags.append(exact_match(pred, reference))
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
from app.core.memory.evaluation.common.metrics import (
bleu1,
f1_score,
jaccard,
)
f1s.append(f1_score(str(pred), str(reference)))
b1s.append(bleu1(str(pred), str(reference)))
jss.append(jaccard(str(pred), str(reference)))

View File

@@ -2,10 +2,10 @@ import argparse
import asyncio
import json
import os
import re
import time
from datetime import datetime
from typing import List, Dict, Any
import re
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
@@ -15,6 +15,7 @@ except Exception:
# 路径与模块导入保持与现有评估脚本一致
import sys
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
@@ -23,17 +24,27 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
sys.path.insert(0, _p)
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.client_factory import MemoryClientFactory
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config_utils import get_embedder_config
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
try:
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
except Exception:
# 兜底:简单实现(必要时)
def f1_score(pred: str, ref: str) -> float:
@@ -226,13 +237,17 @@ async def run_memsciqa_test(
items = all_items[start_index:start_index + sample_size]
# 初始化 LLM纯测试不进行摄入
llm = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test)
connector = Neo4jConnector()
embedder = None
if search_type in ("embedding", "hybrid"):
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)

View File

@@ -5,18 +5,17 @@ OpenAI LLM 客户端实现
"""
import asyncio
from typing import List, Dict, Any
import json
import logging
from typing import Any, Dict, List
from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from app.core.config import settings
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
from app.core.models.base import RedBearModelConfig
from app.core.models.llm import RedBearLLM
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel
logger = logging.getLogger(__name__)
@@ -43,7 +42,7 @@ class OpenAIClient(LLMClient):
# 初始化 Langfuse 回调处理器(如果启用)
self.langfuse_handler = None
if LANGFUSE_ENABLED:
if settings.LANGFUSE_ENABLED:
try:
from langfuse.langchain import CallbackHandler
self.langfuse_handler = CallbackHandler()

View File

@@ -1,430 +0,0 @@
"""
MemSci 记忆系统主入口 - 重构版本
该模块是重构后的记忆系统主入口,使用新的模块化架构。
旧版本入口(app/core/memory/src/main.py)已删除。
主要功能:
1. 协调整个知识提取流水线
2. 支持试运行模式和正常运行模式
3. 使用重构后的 storage_services 模块
4. 提供统一的配置管理和日志记录
作者:Lance77
日期:2025-11-22
"""
# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误
import os
os.environ["LANGCHAIN_TRACING_V2"] = "false"
os.environ["LANGCHAIN_TRACING"] = "false"
import asyncio
import time
from datetime import datetime
from typing import Optional, Callable, Awaitable
from dotenv import load_dotenv
# 导入重构后的模块
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
# 导入数据加载函数
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
get_chunked_dialogs_with_preprocessing,
get_chunked_dialogs_from_preprocessed,
)
# 导入配置模块(而不是直接导入变量)
from app.core.memory.utils.config import definitions as config_defs
from app.core.logging_config import get_memory_logger, log_time
load_dotenv()
logger = get_memory_logger(__name__)
async def main(
# Required configuration parameters (no longer from global variables)
chunker_strategy: str,
group_id: str,
user_id: str,
apply_id: str,
llm_model_id: str,
embedding_model_id: str,
# Optional parameters
dialogue_text: Optional[str] = None,
is_pilot_run: bool = False,
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
):
"""
记忆系统主流程 - 重构版本 (Updated to eliminate global variables)
该函数是重构后的主入口,使用新的模块化架构。
Global variables have been eliminated in favor of explicit parameters.
Args:
chunker_strategy: Chunking strategy to use (required)
group_id: Group ID for the operation (required)
user_id: User ID for the operation (required)
apply_id: Application ID for the operation (required)
llm_model_id: LLM model ID to use (required)
embedding_model_id: Embedding model ID to use (required)
dialogue_text: 输入的对话文本(可选,用于试运行模式)
is_pilot_run: 是否为试运行模式
- True: 试运行模式,不保存到 Neo4j
- False: 正常运行模式,保存到 Neo4j
progress_callback: 可选的进度回调函数
- 类型: Callable[[str, str, Optional[dict]], Awaitable[None]]
- 参数1 (stage): 当前处理阶段标识符
- 参数2 (message): 人类可读的进度消息
- 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果
- 在管线关键点调用以报告进度和结果数据
工作流程:
1. 初始化客户端和配置
2. 加载或准备数据
3. 执行知识提取流水线
4. 保存结果(正常模式)或输出结果(试运行模式)
"""
print("=" * 60)
print("MemSci 知识提取流水线 - 重构版本")
print("=" * 60)
print(f"运行模式: {'试运行不保存到Neo4j' if is_pilot_run else '正常运行保存到Neo4j'}")
print("Using chunker strategy:", chunker_strategy)
print("Using group ID:", group_id)
print("Using model ID:", llm_model_id)
print("Using embedding model ID:", embedding_model_id)
print("=" * 60)
# 初始化日志
log_file = "logs/time.log"
os.makedirs(os.path.dirname(log_file), exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n")
pipeline_start = time.time()
try:
# 步骤 1: 初始化客户端
logger.info("Initializing clients...")
step_start = time.time()
llm_client = get_llm_client(llm_model_id)
# 获取 embedder 配置并转换为 RedBearModelConfig 对象
from app.core.models.base import RedBearModelConfig
embedder_config_dict = get_embedder_config(embedding_model_id)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
neo4j_connector = Neo4jConnector()
log_time("Client Initialization", time.time() - step_start, log_file)
# 步骤 2: 加载或准备数据
logger.info("Loading data...")
logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}")
logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}")
logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}")
step_start = time.time()
if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip():
# 试运行模式:处理前端传入的对话文本
logger.info("[MAIN] ✓ Using frontend dialogue text (pilot run mode)")
import re
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
messages = [
ConversationMessage(role=r, msg=c.strip())
for r, c in matches if c.strip()
]
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
if not messages:
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
# 创建对话上下文和对话数据
context = ConversationContext(msgs=messages)
dialog = DialogData(
context=context,
ref_id="pilot_dialog_1",
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
metadata={"source": "pilot_run", "input_type": "frontend_text"}
)
# 进度回调:开始预处理文本
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
# 对前端传入的对话进行分块处理
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
chunker_strategy=chunker_strategy,
llm_client=llm_client,
)
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
# 进度回调:输出每个分块的结果
if progress_callback:
for dialog in chunked_dialogs:
for i, chunk in enumerate(dialog.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dialog.id,
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 进度回调:预处理文本完成
preprocessing_summary = {
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
else:
# 正常运行模式:从 testdata.json 文件加载
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
logger.info("Loading data from testdata.json...")
test_data_path = os.path.join(
os.path.dirname(__file__), "data", "testdata.json"
)
if not os.path.exists(test_data_path):
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
# 进度回调:开始预处理文本
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
chunker_strategy=chunker_strategy,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
indices=None,
input_data_path=test_data_path,
llm_client=llm_client,
skip_cleaning=True,
)
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
# 进度回调:输出每个分块的结果
if progress_callback:
for dialog in chunked_dialogs:
for i, chunk in enumerate(dialog.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dialog.id,
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 进度回调:预处理文本完成
preprocessing_summary = {
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# 步骤 3: 初始化流水线编排器
logger.info("Initializing extraction orchestrator...")
step_start = time.time()
# 从 runtime.json 加载配置(已经过数据库覆写)
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}")
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
progress_callback=progress_callback, # 传递进度回调
embedding_id=embedding_model_id, # 传递嵌入模型ID
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
# 步骤 4: 执行知识提取流水线
logger.info("Running extraction pipeline...")
step_start = time.time()
# 进度回调:正在知识抽取
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=is_pilot_run, # 传递试运行模式标志
)
# 解包 extraction_result tuple
# extraction_result 是一个包含 7 个元素的 tuple:
# (dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes,
# statement_chunk_edges, statement_entity_edges, entity_edges)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# 进度回调:生成结果
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# 步骤 5: 保存结果或输出结果
if is_pilot_run:
logger.info("Pilot run mode: Skipping Neo4j save")
print("\n试运行模式:跳过 Neo4j 保存,流水线处理完成。")
print("提取结果已生成,可在相关输出中查看。")
else:
logger.info("Normal mode: Saving to Neo4j...")
step_start = time.time()
# 创建索引和约束
try:
from app.repositories.neo4j.create_indexes import (
create_fulltext_indexes,
create_unique_constraints,
)
await create_fulltext_indexes()
await create_unique_constraints()
logger.info("Successfully created indexes and constraints")
except Exception as e:
logger.error(f"Error creating indexes/constraints: {e}")
# 保存数据到 Neo4j
try:
from app.repositories.neo4j.graph_saver import (
save_dialog_and_statements_to_neo4j,
)
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=dialogue_nodes,
chunk_nodes=chunk_nodes,
statement_nodes=statement_nodes,
entity_nodes=entity_nodes,
statement_chunk_edges=statement_chunk_edges,
statement_entity_edges=statement_entity_edges,
entity_edges=entity_edges,
connector=neo4j_connector,
)
if success:
logger.info("Successfully saved all data to Neo4j")
print("\n✓ 成功保存所有数据到 Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
print("\n⚠ 部分数据保存到 Neo4j 失败")
except Exception as e:
logger.error(f"Error saving to Neo4j: {e}", exc_info=True)
print(f"\n✗ 保存到 Neo4j 失败: {e}")
log_time("Neo4j Database Save", time.time() - step_start, log_file)
# 步骤 6: 生成记忆摘要(可选)
try:
logger.info("Generating memory summaries...")
step_start = time.time()
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
Memory_summary_generation,
)
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import (
add_memory_summary_statement_edges,
)
summaries = await Memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
)
if not is_pilot_run:
# 保存记忆摘要到 Neo4j
ms_connector = Neo4jConnector()
try:
await add_memory_summary_nodes(summaries, ms_connector)
await add_memory_summary_statement_edges(summaries, ms_connector)
finally:
await ms_connector.close()
log_time("Memory Summary Generation", time.time() - step_start, log_file)
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
except Exception as e:
logger.error(f"Pipeline execution failed: {e}", exc_info=True)
print(f"\n✗ 流水线执行失败: {e}")
raise
finally:
# 清理资源
try:
await neo4j_connector.close()
except Exception:
pass
# 记录总时间
total_time = time.time() - pipeline_start
log_time("TOTAL PIPELINE TIME", total_time, log_file)
# 添加完成标记
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")
logger.info(f"Timing details saved to: {log_file}")
print("\n" + "=" * 60)
print("✓ 流水线执行完成")
print(f"✓ 总耗时: {total_time:.2f}")
print(f"✓ 详细日志: {log_file}")
print("=" * 60)
if __name__ == "__main__":
    # Running this file directly is no longer supported: the messages below
    # tell the user that the global configuration variables are gone and show
    # the shape of an explicit main(...) call to use instead.
    usage_lines = (
        "⚠️ Warning: This script now requires explicit configuration parameters.",
        "Global variables have been removed. Please provide configuration parameters.",
        "Example usage:",
        " asyncio.run(main(",
        " chunker_strategy='RecursiveChunker',",
        " group_id='your_group_id',",
        " user_id='your_user_id',",
        " apply_id='your_apply_id',",
        " llm_model_id='your_llm_id',",
        " embedding_model_id='your_embedding_id'",
        " ))",
    )
    for usage_line in usage_lines:
        print(usage_line)
    # Abort with an error so a direct invocation never looks like a
    # successful pipeline run.
    raise RuntimeError("Global variables removed. Please provide explicit configuration parameters.")

View File

@@ -20,14 +20,14 @@ Classes:
MemorySummaryNode: Node representing a memory summary
"""
from uuid import uuid4
import re
from datetime import datetime, timezone
from typing import List, Optional
from pydantic import BaseModel, Field, field_validator
import re
from uuid import uuid4
from app.core.memory.utils.data.ontology import TemporalInfo
from app.core.memory.utils.alias_utils import validate_aliases
from app.core.memory.utils.data.ontology import TemporalInfo
from pydantic import BaseModel, Field, field_validator
def parse_historical_datetime(v):
@@ -361,7 +361,7 @@ class ExtractedEntityNode(Node):
description="Entity aliases - alternative names for this entity"
)
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
fact_summary: str = Field(..., description="Summary of the fact about this entity")
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")

View File

@@ -10,9 +10,10 @@ Classes:
"""
from typing import List, Optional
from pydantic import BaseModel, Field, ConfigDict
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field
class Entity(BaseModel):
"""Represents an extracted entity from dialogue.

View File

@@ -5,7 +5,10 @@ import math
import os
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
@@ -14,16 +17,14 @@ from app.core.memory.models.variate_config import ForgettingEngineConfig
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import (
ForgettingEngine,
)
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.config.config_utils import (
get_embedder_config,
get_pipeline_config,
)
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
from app.core.memory.utils.data.text_utils import extract_plain_query
from app.core.memory.utils.data.time_utils import normalize_date_safe
from app.core.memory.utils.llm.llm_utils import get_reranker_client
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import (
search_graph,
search_graph_by_chunk_id,
@@ -34,6 +35,7 @@ from app.repositories.neo4j.graph_search import (
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from dotenv import load_dotenv
load_dotenv()
@@ -324,229 +326,229 @@ def apply_reranker_placeholder(
If config enables reranker, annotate items with a final_score equal to combined_score
and keep ordering. This is a no-op reranker to be replaced later.
"""
try:
rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
except Exception as e:
logger.debug(f"Failed to load reranker config: {e}")
rc = {}
if not rc or not rc.get("enabled", False):
return results
# try:
# rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
# except Exception as e:
# logger.debug(f"Failed to load reranker config: {e}")
# rc = {}
# if not rc or not rc.get("enabled", False):
# return results
top_k = int(rc.get("top_k", 100))
model_name = rc.get("model", "placeholder")
# top_k = int(rc.get("top_k", 100))
# model_name = rc.get("model", "placeholder")
for cat, items in results.items():
head = items[:top_k]
for it in head:
base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
it["final_score"] = base
it["reranker_model"] = model_name
# Keep overall order by final_score if present, otherwise combined/score
results[cat] = sorted(
items,
key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
reverse=True,
)
# for cat, items in results.items():
# head = items[:top_k]
# for it in head:
# base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
# it["final_score"] = base
# it["reranker_model"] = model_name
# # Keep overall order by final_score if present, otherwise combined/score
# results[cat] = sorted(
# items,
# key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
# reverse=True,
# )
return results
async def apply_llm_reranker(
results: Dict[str, List[Dict[str, Any]]],
query_text: str,
reranker_client: Optional[Any] = None,
llm_weight: Optional[float] = None,
top_k: Optional[int] = None,
batch_size: Optional[int] = None,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Apply LLM-based reranking to search results.
# async def apply_llm_reranker(
# results: Dict[str, List[Dict[str, Any]]],
# query_text: str,
# reranker_client: Optional[Any] = None,
# llm_weight: Optional[float] = None,
# top_k: Optional[int] = None,
# batch_size: Optional[int] = None,
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Apply LLM-based reranking to search results.
Args:
results: Search results organized by category
query_text: Original search query
reranker_client: Optional pre-initialized reranker client
llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
top_k: Maximum number of items to rerank per category
batch_size: Number of items to process concurrently
# Args:
# results: Search results organized by category
# query_text: Original search query
# reranker_client: Optional pre-initialized reranker client
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
# top_k: Maximum number of items to rerank per category
# batch_size: Number of items to process concurrently
Returns:
Reranked results with final_score and reranker_model fields
"""
# Load reranker configuration from runtime.json
try:
rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
except Exception as e:
logger.debug(f"Failed to load reranker config: {e}")
rc = {}
# Returns:
# Reranked results with final_score and reranker_model fields
# """
# # Load reranker configuration from runtime.json
# # try:
# # rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
# # except Exception as e:
# # logger.debug(f"Failed to load reranker config: {e}")
# # rc = {}
# Check if reranking is enabled
enabled = rc.get("enabled", False)
if not enabled:
logger.debug("LLM reranking is disabled in configuration")
return results
# # Check if reranking is enabled
# enabled = rc.get("enabled", False)
# if not enabled:
# logger.debug("LLM reranking is disabled in configuration")
# return results
# Load configuration parameters with defaults
llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
top_k = top_k if top_k is not None else rc.get("top_k", 20)
batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
# # Load configuration parameters with defaults
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
# top_k = top_k if top_k is not None else rc.get("top_k", 20)
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
# Initialize reranker client if not provided
if reranker_client is None:
try:
reranker_client = get_reranker_client()
except Exception as e:
logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
return results
# # Initialize reranker client if not provided
# if reranker_client is None:
# try:
# reranker_client = get_reranker_client()
# except Exception as e:
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
# return results
# Get model name for metadata
model_name = getattr(reranker_client, 'model_name', 'unknown')
# # Get model name for metadata
# model_name = getattr(reranker_client, 'model_name', 'unknown')
# Process each category
reranked_results = {}
for category in ["statements", "chunks", "entities", "summaries"]:
items = results.get(category, [])
if not items:
reranked_results[category] = []
continue
# # Process each category
# reranked_results = {}
# for category in ["statements", "chunks", "entities", "summaries"]:
# items = results.get(category, [])
# if not items:
# reranked_results[category] = []
# continue
# Select top K items by combined_score for reranking
sorted_items = sorted(
items,
key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
reverse=True
)
# # Select top K items by combined_score for reranking
# sorted_items = sorted(
# items,
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
# reverse=True
# )
top_items = sorted_items[:top_k]
remaining_items = sorted_items[top_k:]
# top_items = sorted_items[:top_k]
# remaining_items = sorted_items[top_k:]
# Extract text content from each item
def extract_text(item: Dict[str, Any]) -> str:
"""Extract text content from a result item."""
# Try different text fields based on category
text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
return str(text).strip()
# # Extract text content from each item
# def extract_text(item: Dict[str, Any]) -> str:
# """Extract text content from a result item."""
# # Try different text fields based on category
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
# return str(text).strip()
# Batch items for concurrent processing
batches = []
for i in range(0, len(top_items), batch_size):
batch = top_items[i:i + batch_size]
batches.append(batch)
# # Batch items for concurrent processing
# batches = []
# for i in range(0, len(top_items), batch_size):
# batch = top_items[i:i + batch_size]
# batches.append(batch)
# Process batches concurrently
async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Process a batch of items with LLM relevance scoring."""
scored_batch = []
# # Process batches concurrently
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# """Process a batch of items with LLM relevance scoring."""
# scored_batch = []
for item in batch:
item_text = extract_text(item)
# for item in batch:
# item_text = extract_text(item)
# Skip items with no text
if not item_text:
item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item_copy["final_score"] = combined_score
item_copy["llm_relevance_score"] = 0.0
item_copy["reranker_model"] = model_name
scored_batch.append(item_copy)
continue
# # Skip items with no text
# if not item_text:
# item_copy = item.copy()
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
# item_copy["final_score"] = combined_score
# item_copy["llm_relevance_score"] = 0.0
# item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy)
# continue
# Create relevance scoring prompt
prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
# # Create relevance scoring prompt
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
Query: {query_text}
# Query: {query_text}
Result: {item_text}
# Result: {item_text}
Respond with only a number between 0.0 and 1.0, where:
- 0.0 means completely irrelevant
- 1.0 means perfectly relevant
# Respond with only a number between 0.0 and 1.0, where:
# - 0.0 means completely irrelevant
# - 1.0 means perfectly relevant
Relevance score:"""
# Relevance score:"""
# Send request to LLM
try:
messages = [{"role": "user", "content": prompt}]
response = await reranker_client.chat(messages)
# # Send request to LLM
# try:
# messages = [{"role": "user", "content": prompt}]
# response = await reranker_client.chat(messages)
# Parse LLM response to extract relevance score
response_text = str(response.content if hasattr(response, 'content') else response).strip()
# # Parse LLM response to extract relevance score
# response_text = str(response.content if hasattr(response, 'content') else response).strip()
# Try to extract a float from the response
try:
# Remove any non-numeric characters except decimal point
import re
score_match = re.search(r'(\d+\.?\d*)', response_text)
if score_match:
llm_score = float(score_match.group(1))
# Clamp to [0.0, 1.0]
llm_score = max(0.0, min(1.0, llm_score))
else:
raise ValueError("No numeric score found in response")
except (ValueError, AttributeError) as e:
logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
llm_score = None
# # Try to extract a float from the response
# try:
# # Remove any non-numeric characters except decimal point
# import re
# score_match = re.search(r'(\d+\.?\d*)', response_text)
# if score_match:
# llm_score = float(score_match.group(1))
# # Clamp to [0.0, 1.0]
# llm_score = max(0.0, min(1.0, llm_score))
# else:
# raise ValueError("No numeric score found in response")
# except (ValueError, AttributeError) as e:
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
# llm_score = None
# Calculate final score
item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
# # Calculate final score
# item_copy = item.copy()
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
if llm_score is not None:
final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
item_copy["llm_relevance_score"] = llm_score
else:
# Use combined_score as fallback
final_score = combined_score
item_copy["llm_relevance_score"] = combined_score
# if llm_score is not None:
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
# item_copy["llm_relevance_score"] = llm_score
# else:
# # Use combined_score as fallback
# final_score = combined_score
# item_copy["llm_relevance_score"] = combined_score
item_copy["final_score"] = final_score
item_copy["reranker_model"] = model_name
scored_batch.append(item_copy)
except Exception as e:
logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item_copy["final_score"] = combined_score
item_copy["llm_relevance_score"] = combined_score
item_copy["reranker_model"] = model_name
scored_batch.append(item_copy)
# item_copy["final_score"] = final_score
# item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy)
# except Exception as e:
# logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
# item_copy = item.copy()
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
# item_copy["final_score"] = combined_score
# item_copy["llm_relevance_score"] = combined_score
# item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy)
return scored_batch
# return scored_batch
# Process all batches concurrently
try:
batch_tasks = [process_batch(batch) for batch in batches]
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
# # Process all batches concurrently
# try:
# batch_tasks = [process_batch(batch) for batch in batches]
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
# Merge batch results
scored_items = []
for result in batch_results:
if isinstance(result, Exception):
logger.warning(f"Batch processing failed: {result}")
continue
scored_items.extend(result)
# # Merge batch results
# scored_items = []
# for result in batch_results:
# if isinstance(result, Exception):
# logger.warning(f"Batch processing failed: {result}")
# continue
# scored_items.extend(result)
# Add remaining items (not in top K) with their combined_score as final_score
for item in remaining_items:
item_copy = item.copy()
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item_copy["final_score"] = combined_score
item_copy["reranker_model"] = model_name
scored_items.append(item_copy)
# # Add remaining items (not in top K) with their combined_score as final_score
# for item in remaining_items:
# item_copy = item.copy()
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
# item_copy["final_score"] = combined_score
# item_copy["reranker_model"] = model_name
# scored_items.append(item_copy)
# Sort all items by final_score in descending order
scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
reranked_results[category] = scored_items
# # Sort all items by final_score in descending order
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
# reranked_results[category] = scored_items
except Exception as e:
logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
# Return original items with combined_score as final_score
for item in items:
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
item["final_score"] = combined_score
item["reranker_model"] = model_name
reranked_results[category] = items
# except Exception as e:
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
# # Return original items with combined_score as final_score
# for item in items:
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
# item["final_score"] = combined_score
# item["reranker_model"] = model_name
# reranked_results[category] = items
return reranked_results
# return reranked_results
async def run_hybrid_search(
@@ -556,7 +558,7 @@ async def run_hybrid_search(
limit: int,
include: List[str],
output_path: str | None,
embedding_id: str,
memory_config: "MemoryConfig",
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
@@ -564,11 +566,14 @@ async def run_hybrid_search(
"""
Run search with specified type: 'keyword', 'embedding', or 'hybrid'
Args:
memory_config: MemoryConfig object containing embedding_model_id and config_id
"""
# Start overall timing
search_start_time = time.time()
latency_metrics = {}
logger.info(f"using embedding_id:{embedding_id}...")
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
# Clean and normalize the incoming query before use/logging
query_text = extract_plain_query(query_text)
@@ -621,7 +626,9 @@ async def run_hybrid_search(
# 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig
config_load_start = time.time()
embedder_config_dict = get_embedder_config(embedding_id)
with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
rb_config = RedBearModelConfig(
model_name=embedder_config_dict["model_name"],
provider=embedder_config_dict["provider"],
@@ -683,7 +690,7 @@ async def run_hybrid_search(
if use_forgetting_rerank:
# Load forgetting parameters from pipeline config
try:
pc = get_pipeline_config()
pc = get_pipeline_config(memory_config)
forgetting_cfg = pc.forgetting_engine
except Exception as e:
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
@@ -711,16 +718,16 @@ async def run_hybrid_search(
# Apply LLM reranking if enabled
llm_rerank_applied = False
if use_llm_rerank:
try:
reranked_results = await apply_llm_reranker(
results=reranked_results,
query_text=query_text,
)
llm_rerank_applied = True
logger.info("LLM reranking applied successfully")
except Exception as e:
logger.warning(f"LLM reranking failed: {e}, using previous scores")
# if use_llm_rerank:
# try:
# reranked_results = await apply_llm_reranker(
# results=reranked_results,
# query_text=query_text,
# )
# llm_rerank_applied = True
# logger.info("LLM reranking applied successfully")
# except Exception as e:
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
results["reranked_results"] = reranked_results
results["combined_summary"] = {
@@ -896,90 +903,95 @@ async def search_chunk_by_chunk_id(
return {"chunks": chunks}
def main():
"""Main entry point for the hybrid graph search CLI.
# def main():
# """Main entry point for the hybrid graph search CLI.
Parses command line arguments and executes search with specified parameters.
Supports keyword, embedding, and hybrid search modes.
"""
parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
parser.add_argument(
"--query", "-q", required=True, help="Free-text query to search"
)
parser.add_argument(
"--search-type",
"-t",
choices=["keyword", "embedding", "hybrid"],
default="hybrid",
help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
)
parser.add_argument(
"--embedding-name",
"-m",
default="openai/nomic-embed-text:v1.5",
help="Embedding config name from config.json (default: openai/nomic-embed-text:v1.5)",
)
parser.add_argument(
"--group-id",
"-g",
default=None,
help="Optional group_id to filter results (default: None)",
)
parser.add_argument(
"--limit",
"-k",
type=int,
default=5,
help="Max number of results per type (default: 5)",
)
parser.add_argument(
"--include",
"-i",
nargs="+",
default=["statements", "chunks", "entities", "summaries"],
choices=["statements", "chunks", "entities", "summaries"],
help="Which targets to search for embedding search (default: statements chunks entities summaries)"
)
parser.add_argument(
"--output",
"-o",
default="search_results.json",
help="Path to save the search results JSON (default: search_results.json)",
)
parser.add_argument(
"--rerank-alpha",
"-a",
type=float,
default=0.6,
help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
)
parser.add_argument(
"--forgetting-rerank",
action="store_true",
help="Apply forgetting curve during reranking for hybrid search.",
)
parser.add_argument(
"--llm-rerank",
action="store_true",
help="Apply LLM-based reranking for hybrid search.",
)
args = parser.parse_args()
# Parses command line arguments and executes search with specified parameters.
# Supports keyword, embedding, and hybrid search modes.
# """
# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
# parser.add_argument(
# "--query", "-q", required=True, help="Free-text query to search"
# )
# parser.add_argument(
# "--search-type",
# "-t",
# choices=["keyword", "embedding", "hybrid"],
# default="hybrid",
# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
# )
# parser.add_argument(
# "--config-id",
# "-c",
# type=int,
# required=True,
# help="Database configuration ID (required)",
# )
# parser.add_argument(
# "--group-id",
# "-g",
# default=None,
# help="Optional group_id to filter results (default: None)",
# )
# parser.add_argument(
# "--limit",
# "-k",
# type=int,
# default=5,
# help="Max number of results per type (default: 5)",
# )
# parser.add_argument(
# "--include",
# "-i",
# nargs="+",
# default=["statements", "chunks", "entities", "summaries"],
# choices=["statements", "chunks", "entities", "summaries"],
# help="Which targets to search for embedding search (default: statements chunks entities summaries)"
# )
# parser.add_argument(
# "--output",
# "-o",
# default="search_results.json",
# help="Path to save the search results JSON (default: search_results.json)",
# )
# parser.add_argument(
# "--rerank-alpha",
# "-a",
# type=float,
# default=0.6,
# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
# )
# parser.add_argument(
# "--forgetting-rerank",
# action="store_true",
# help="Apply forgetting curve during reranking for hybrid search.",
# )
# parser.add_argument(
# "--llm-rerank",
# action="store_true",
# help="Apply LLM-based reranking for hybrid search.",
# )
# args = parser.parse_args()
asyncio.run(
run_hybrid_search(
query_text=args.query,
search_type=args.search_type,
group_id=args.group_id,
limit=args.limit,
include=args.include,
output_path=args.output,
embedding_id=config_defs.SELECTED_EMBEDDING_ID,
rerank_alpha=args.rerank_alpha,
use_forgetting_rerank=args.forgetting_rerank,
use_llm_rerank=args.llm_rerank,
)
)
# # Load memory config from database
# from app.services.memory_config_service import MemoryConfigService
# memory_config = MemoryConfigService.load_memory_config(args.config_id)
# asyncio.run(
# run_hybrid_search(
# query_text=args.query,
# search_type=args.search_type,
# group_id=args.group_id,
# limit=args.limit,
# include=args.include,
# output_path=args.output,
# memory_config=memory_config,
# rerank_alpha=args.rerank_alpha,
# use_forgetting_rerank=args.forgetting_rerank,
# use_llm_rerank=args.llm_rerank,
# )
# )
if __name__ == "__main__":
main()
# if __name__ == "__main__":
# main()

View File

@@ -1,19 +1,22 @@
"""
去重功能函数
"""
from app.core.memory.models.variate_config import DedupConfig
from typing import List, Dict, Tuple, Any
from app.core.memory.models.graph_models import(
StatementEntityEdge,
EntityEntityEdge,
ExtractedEntityNode
)
import os
from datetime import datetime
import difflib # 提供字符串相似度计算工具
import asyncio
import difflib # 提供字符串相似度计算工具
import importlib
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Tuple
from app.core.memory.models.graph_models import (
EntityEntityEdge,
ExtractedEntityNode,
StatementEntityEdge,
)
from app.core.memory.models.variate_config import DedupConfig
# 模块级类型统一工具函数
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
"""统一实体类型基于LLM建议或启发式规则选择最合适的类型。
@@ -705,7 +708,8 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
id_redirect: Dict[str, str],
config: DedupConfig | None = None,
config: DedupConfig,
llm_client = None,
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
"""
基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。
@@ -717,26 +721,13 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
"""
llm_records: List[str] = []
try:
# 优先使用运行时配置;若未提供配置,使用模型默认值,不再回退到环境变量
enable_switch = (
bool(config.enable_llm_dedup_blockwise) if config is not None else DedupConfig().enable_llm_dedup_blockwise
)
if not enable_switch:
return deduped_entities, id_redirect, llm_records
# 从配置读取 LLM 迭代参数;若无配置则使用 DedupConfig 的默认值
_defaults = DedupConfig()
block_size = (config.llm_block_size if config is not None else _defaults.llm_block_size)
block_concurrency = (config.llm_block_concurrency if config is not None else _defaults.llm_block_concurrency)
pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency)
max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds)
# 动态导入 llm 客户端(修正导入路径)
try:
llm_utils_mod = importlib.import_module("app.core.memory.utils.llm.llm_utils")
get_llm_client_fn = llm_utils_mod.get_llm_client
except Exception as e:
llm_records.append(f"[LLM错误] 无法导入 llm_utils 模块: {e}")
if not bool(config.enable_llm_dedup_blockwise):
return deduped_entities, id_redirect, llm_records
# 从配置读取 LLM 迭代参数
block_size = config.llm_block_size
block_concurrency = config.llm_block_concurrency
pair_concurrency = config.llm_pair_concurrency
max_rounds = config.llm_max_rounds
try:
llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm")
@@ -745,14 +736,9 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
llm_records.append(f"[LLM错误] 无法导入 entity_dedup_llm 模块: {e}")
return deduped_entities, id_redirect, llm_records
# 获取 LLM 客户端
try:
llm_client = get_llm_client_fn()
if llm_client is None:
llm_records.append("[LLM错误] LLM 客户端初始化失败:返回 None")
return deduped_entities, id_redirect, llm_records
except Exception as e:
llm_records.append(f"[LLM错误] 获取 LLM 客户端失败: {e}")
# 验证 LLM 客户端
if llm_client is None:
llm_records.append("[LLM错误] LLM 客户端未提供")
return deduped_entities, id_redirect, llm_records
llm_redirect, llm_records = await llm_fn(
@@ -813,7 +799,8 @@ async def LLM_disamb_decision(
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
id_redirect: Dict[str, str],
config: DedupConfig | None = None,
config: DedupConfig,
llm_client = None,
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]:
"""
预消歧阶段对“同名但类型不同”的实体对调用LLM进行消歧
@@ -824,22 +811,16 @@ async def LLM_disamb_decision(
disamb_records: List[str] = []
blocked_pairs: set[tuple[str, str]] = set()
try:
enable_switch = (
config.enable_llm_disambiguation
if config is not None
else DedupConfig().enable_llm_disambiguation
)
if not bool(enable_switch):
if not bool(config.enable_llm_disambiguation):
return deduped_entities, id_redirect, blocked_pairs, disamb_records
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import (
llm_disambiguate_pairs_iterative,
)
# 获取 LLM 客户端并验证
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
# 验证 LLM 客户端
if llm_client is None:
disamb_records.append("[DISAMB错误] LLM 客户端初始化失败:返回 None")
disamb_records.append("[DISAMB错误] LLM 客户端未提供")
return deduped_entities, id_redirect, blocked_pairs, disamb_records
merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative(
@@ -895,6 +876,7 @@ async def deduplicate_entities_and_edges(
report_append: bool = False,
report_stage_notes: List[str] | None = None,
dedup_config: DedupConfig | None = None,
llm_client = None,
) -> Tuple[
List[ExtractedEntityNode],
List[StatementEntityEdge],
@@ -911,7 +893,7 @@ async def deduplicate_entities_and_edges(
# 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并
deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision(
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
)
# 2) 模糊匹配(本地规则)
@@ -936,7 +918,7 @@ async def deduplicate_entities_and_edges(
if should_trigger_llm:
deduped_entities, id_redirect, llm_decision_records = await LLM_decision(
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
)
else:
llm_decision_records = []

View File

@@ -10,15 +10,27 @@
from __future__ import annotations
from typing import List, Dict, Any, Tuple
from datetime import datetime
from typing import Any, Dict, List, Tuple
from app.core.memory.models.graph_models import (
EntityEntityEdge,
ExtractedEntityNode,
StatementEntityEdge,
)
from app.core.memory.models.variate_config import DedupConfig
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( # 导入报告写入以在跳过时追加说明
_write_dedup_fusion_report,
deduplicate_entities_and_edges,
)
from app.repositories.neo4j.graph_search import (
get_dedup_candidates_for_entities, # 导入ge函数用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互
from app.repositories.neo4j.graph_search import get_dedup_candidates_for_entities # 导入ge函数,用于 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges, _write_dedup_fusion_report # 导入报告写入以在跳过时追加说明
from app.core.memory.models.variate_config import DedupConfig
from app.repositories.neo4j.neo4j_connector import (
Neo4jConnector, # 导入 Neo4j 数据库连接器类,用于 Neo4j 数据库进行交互
)
def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt用于将任意类型的输入值解析为datetime对象处理实体节点中的时间字段
@@ -72,6 +84,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
dedup_config: DedupConfig | None = None,
llm_client = None,
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
"""
第二层去重消歧:
@@ -137,13 +150,14 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes)
# 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重)
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges = await deduplicate_entities_and_edges(
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges, _ = await deduplicate_entities_and_edges(
union_entities,
statement_entity_edges,
entity_entity_edges,
report_stage="第二层去重消歧",
report_append=True,
dedup_config=dedup_config,
llm_client=llm_client,
)
return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges

View File

@@ -1,23 +1,27 @@
from __future__ import annotations
from typing import List, Tuple, Optional
from typing import List, Optional, Tuple
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.utils.config.config_utils import get_pipeline_config
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import second_layer_dedup_and_merge_with_neo4j
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.models.graph_models import (
DialogueNode,
ChunkNode,
StatementNode,
DialogueNode,
EntityEntityEdge,
ExtractedEntityNode,
StatementChunkEdge,
StatementEntityEdge,
EntityEntityEdge,
StatementNode,
)
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
)
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
second_layer_dedup_and_merge_with_neo4j,
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def dedup_layers_and_merge_and_return(
@@ -29,8 +33,9 @@ async def dedup_layers_and_merge_and_return(
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData],
pipeline_config: Optional[ExtractionPipelineConfig] = None,
pipeline_config: ExtractionPipelineConfig,
connector: Optional[Neo4jConnector] = None,
llm_client = None,
) -> Tuple[
List[DialogueNode],
List[ChunkNode],
@@ -48,12 +53,9 @@ async def dedup_layers_and_merge_and_return(
返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。
"""
# 默认从 runtime.json 加载管线配置,避免回退到环境变量
# pipeline_config is required - caller must provide it
if pipeline_config is None:
try:
pipeline_config = get_pipeline_config()
except Exception:
pipeline_config = None
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
# 先探测 group_id决定报告写入策略
group_id: Optional[str] = None
@@ -70,6 +72,7 @@ async def dedup_layers_and_merge_and_return(
report_stage="第一层去重消歧",
report_append=False,
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
llm_client=llm_client,
)
# 初始化第二层融合结果为第一层结果
@@ -88,6 +91,7 @@ async def dedup_layers_and_merge_and_return(
statement_entity_edges=dedup_statement_entity_edges,
entity_entity_edges=dedup_entity_entity_edges,
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
llm_client=llm_client,
)
else:
print("Skip second-layer dedup: missing connector")

View File

@@ -1253,6 +1253,7 @@ class ExtractionOrchestrator:
report_stage="第一层去重消歧(试运行)",
report_append=False,
dedup_config=self.config.deduplication,
llm_client=self.llm_client,
)
# 保存去重消歧的详细记录到实例变量
@@ -1284,6 +1285,7 @@ class ExtractionOrchestrator:
dialog_data_list,
self.config,
self.connector,
llm_client=self.llm_client,
)
# 解包返回值

View File

@@ -5,11 +5,13 @@
"""
import asyncio
from typing import List, Dict, Any, Tuple
from app.core.memory.models.message_models import DialogData
from typing import Any, Dict, List, Tuple
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.models.message_models import DialogData
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
class EmbeddingGenerator:
@@ -21,7 +23,9 @@ class EmbeddingGenerator:
Args:
embedding_id: 嵌入模型 ID
"""
embedder_config = get_embedder_config(embedding_id)
with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config = config_service.get_embedder_config(embedding_id)
self.embedder_client = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(embedder_config),
)

View File

@@ -1,21 +1,17 @@
import os
import asyncio
from datetime import datetime
from typing import List, Optional
from pydantic import Field, field_validator
from uuid import uuid4
from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.base_response import RobustLLMResponse
from app.core.memory.models.graph_models import MemorySummaryNode
from app.core.memory.models.message_models import DialogData
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
from pydantic import Field
logger = get_memory_logger(__name__)
from app.core.memory.models.graph_models import MemorySummaryNode
from app.core.memory.models.base_response import RobustLLMResponse
from app.core.models.base import RedBearModelConfig
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
from uuid import uuid4
class MemorySummaryResponse(RobustLLMResponse):
@@ -91,22 +87,17 @@ async def _process_chunk_summary(
return None
async def Memory_summary_generation(
async def memory_summary_generation(
chunked_dialogs: List[DialogData],
llm_client,
embedding_id,
embedder_client: OpenAIEmbedderClient,
) -> List[MemorySummaryNode]:
"""Generate memory summaries per chunk, embed them, and return nodes."""
embedder_cfg_dict = get_embedder_config(embedding_id)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(embedder_cfg_dict),
)
# Collect all tasks for parallel processing
tasks = []
for dialog in chunked_dialogs:
for chunk in dialog.chunks:
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder))
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder_client))
# Process all chunks in parallel
results = await asyncio.gather(*tasks, return_exceptions=False)

View File

@@ -1,17 +1,21 @@
import os
import asyncio
import logging
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.models.message_models import DialogData, Statement
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
from app.core.memory.models.variate_config import StatementExtractionConfig
from app.core.memory.utils.data.ontology import (
LABEL_DEFINITIONS,
RelevenceInfo,
StatementType,
TemporalInfo,
)
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo, RelevenceInfo
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)

View File

@@ -8,21 +8,25 @@
4. 反思结果应用 - 更新记忆库
"""
import asyncio
import json
import logging
import asyncio
import os
import time
from typing import List, Dict, Any, Optional
from enum import Enum
import uuid
from pydantic import BaseModel
from enum import Enum
from typing import Any, Dict, List, Optional
from app.core.response_utils import success
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all
from app.repositories.neo4j.neo4j_update import neo4j_data
from app.repositories.neo4j.cypher_queries import (
neo4j_query_all,
neo4j_query_part,
neo4j_statement_all,
neo4j_statement_part,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.neo4j_update import neo4j_data
from pydantic import BaseModel
# 配置日志
_root_logger = logging.getLogger()
@@ -135,14 +139,20 @@ class ReflectionEngine:
self.neo4j_connector = Neo4jConnector()
if self.llm_client is None:
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config import definitions as config_defs
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
with get_db_context() as db:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
elif isinstance(self.llm_client, str):
# 如果 llm_client 是字符串model_id则用它初始化客户端
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
model_id = self.llm_client
self.llm_client = get_llm_client(model_id)
with get_db_context() as db:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(model_id)
if self.get_data_func is None:
from app.core.memory.utils.config.get_data import get_data
@@ -154,11 +164,15 @@ class ReflectionEngine:
self.get_data_statement = get_data_statement
if self.render_evaluate_prompt_func is None:
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
from app.core.memory.utils.prompt.template_render import (
render_evaluate_prompt,
)
self.render_evaluate_prompt_func = render_evaluate_prompt
if self.render_reflexion_prompt_func is None:
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
from app.core.memory.utils.prompt.template_render import (
render_reflexion_prompt,
)
self.render_reflexion_prompt_func = render_reflexion_prompt
if self.conflict_schema is None:
@@ -170,7 +184,9 @@ class ReflectionEngine:
self.reflexion_schema = ReflexionResultSchema
if self.update_query is None:
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
from app.repositories.neo4j.cypher_queries import (
UPDATE_STATEMENT_INVALID_AT,
)
self.update_query = UPDATE_STATEMENT_INVALID_AT
self._lazy_init_done = True

View File

@@ -4,10 +4,20 @@
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
"""
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
from app.core.memory.storage_services.search.search_strategy import (
SearchResult,
SearchStrategy,
)
from app.core.memory.storage_services.search.semantic_search import (
SemanticSearchStrategy,
)
__all__ = [
"SearchStrategy",
@@ -34,7 +44,7 @@ async def run_hybrid_search(
include: list[str] | None = None,
alpha: float = 0.6,
use_forgetting_curve: bool = False,
embedding_id: str | None = None,
memory_config: "MemoryConfig" = None,
**kwargs
) -> dict:
"""运行混合搜索向后兼容的函数式API
@@ -51,24 +61,26 @@ async def run_hybrid_search(
include: 要包含的搜索类别列表
alpha: BM25分数权重0.0-1.0
use_forgetting_curve: 是否使用遗忘曲线
embedding_id: 嵌入模型ID
memory_config: MemoryConfig object containing embedding_model_id
**kwargs: 其他参数
Returns:
dict: 搜索结果字典格式与旧API兼容
"""
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 使用提供的embedding_id或默认值
emb_id = embedding_id or config_defs.SELECTED_EMBEDDING_ID
if not memory_config:
raise ValueError("memory_config is required for search")
# 初始化客户端
connector = Neo4jConnector()
embedder_config_dict = get_embedder_config(emb_id)
with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)

View File

@@ -5,15 +5,20 @@
使用余弦相似度进行语义匹配。
"""
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
from app.core.logging_config import get_memory_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.storage_services.search.search_strategy import (
SearchResult,
SearchStrategy,
)
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
logger = get_memory_logger(__name__)
@@ -62,7 +67,9 @@ class SemanticSearchStrategy(SearchStrategy):
"""
try:
# 从数据库读取嵌入器配置
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config_dict = config_service.get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
rb_config = RedBearModelConfig(
model_name=embedder_config_dict["model_name"],
provider=embedder_config_dict["provider"],

View File

@@ -9,7 +9,6 @@ from .config_utils import (
get_chunker_config,
get_embedder_config,
get_model_config,
get_neo4j_config,
get_picture_config,
get_pipeline_config,
get_pruning_config,
@@ -41,7 +40,6 @@ __all__ = [
# config_utils
"get_model_config",
"get_embedder_config",
"get_neo4j_config",
"get_chunker_config",
"get_pipeline_config",
"get_pruning_config",

View File

@@ -1,90 +1,74 @@
from app.core.memory.models.variate_config import (
DedupConfig,
ExtractionPipelineConfig,
ForgettingEngineConfig,
StatementExtractionConfig,
)
from app.core.memory.utils.config.definitions import CONFIG
from app.db import get_db
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService
from fastapi import status
from fastapi.exceptions import HTTPException
from sqlalchemy.orm import Session
"""
Configuration utilities - Backward compatibility layer
DEPRECATED: get_model_config and get_embedder_config now require a db session argument.
New code should use a MemoryConfigService(db) instance directly.
Functions that don't require db (get_pipeline_config, get_pruning_config)
are still re-exported here.
"""
import warnings
from app.services.memory_config_service import MemoryConfigService
# Re-export the MemoryConfigService helpers under their legacy module-level
# names so existing `from ...config_utils import get_pipeline_config` imports
# keep working.
# NOTE(review): they are read off the class rather than an instance, so they
# are presumably static methods that need no db session - confirm against
# MemoryConfigService.
get_pipeline_config = MemoryConfigService.get_pipeline_config
get_pruning_config = MemoryConfigService.get_pruning_config
def get_model_config(model_id: str, db=None):
    """Return the configuration for ``model_id`` via ``MemoryConfigService``.

    DEPRECATED: Use ``MemoryConfigService(db).get_model_config(model_id)``
    directly. This wrapper is a backward-compatibility shim and no longer
    opens a database session on its own.

    Args:
        model_id: Identifier of the model whose configuration is requested.
        db: Database session passed to ``MemoryConfigService``. Passing
            ``None`` (the default) raises ``ValueError``.

    Returns:
        Whatever ``MemoryConfigService(db).get_model_config(model_id)`` returns.

    Raises:
        ValueError: If ``db`` is ``None``.
    """
    # Fail loudly instead of silently pulling a session out of a generator
    # that would never be closed.
    if db is None:
        raise ValueError(
            "get_model_config now requires a db session. "
            "Use MemoryConfigService(db).get_model_config(model_id) directly."
        )
    return MemoryConfigService(db).get_model_config(model_id)
def get_embedder_config(embedding_id: str, db=None):
    """Return the embedder configuration for ``embedding_id``.

    DEPRECATED: Use ``MemoryConfigService(db).get_embedder_config(embedding_id)``
    directly. This wrapper is a backward-compatibility shim and no longer
    opens a database session on its own.

    Args:
        embedding_id: Identifier of the embedding model whose configuration
            is requested.
        db: Database session passed to ``MemoryConfigService``. Passing
            ``None`` (the default) raises ``ValueError``.

    Returns:
        Whatever ``MemoryConfigService(db).get_embedder_config(embedding_id)``
        returns.

    Raises:
        ValueError: If ``db`` is ``None``.
    """
    # Fail loudly instead of silently pulling a session out of a generator
    # that would never be closed.
    if db is None:
        raise ValueError(
            "get_embedder_config now requires a db session. "
            "Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."
        )
    return MemoryConfigService(db).get_embedder_config(embedding_id)
def get_neo4j_config() -> dict:
"""Retrieves the Neo4j configuration from the config file."""
return CONFIG.get("neo4j", {})
def get_picture_config(llm_name: str) -> dict:
    """Return the picture-recognition config entry whose ``llm_name`` matches.

    .. deprecated::
        This function is deprecated and will be removed in a future version.
        Use database-backed model configuration instead.

    Args:
        llm_name: Value compared against each entry's ``"llm_name"`` key.

    Returns:
        The first matching entry of the legacy ``"picture_recognition"`` list.

    Raises:
        ValueError: If no entry matches ``llm_name``.
    """
    warnings.warn(
        "get_picture_config is deprecated and will be removed in a future version. "
        "Use database-backed model configuration instead.",
        DeprecationWarning,
        stacklevel=2
    )
    # The legacy module-level CONFIG mapping may no longer be imported into
    # this module; look it up defensively so a missing name produces the
    # documented ValueError below rather than a NameError.
    legacy_config = globals().get("CONFIG") or {}
    for model_config in legacy_config.get("picture_recognition", []):
        if model_config["llm_name"] == llm_name:
            return model_config
    raise ValueError(f"Model '{llm_name}' not found in config.json")
def get_voice_config(llm_name: str) -> dict:
"""Retrieves the configuration for a specific model from the config file."""
"""Retrieves the configuration for a specific model from the config file.
.. deprecated::
This function is deprecated and will be removed in a future version.
Use database-backed model configuration instead.
"""
warnings.warn(
"get_voice_config is deprecated and will be removed in a future version. "
"Use database-backed model configuration instead.",
DeprecationWarning,
stacklevel=2
)
for model_config in CONFIG.get("voice_recognition", []):
if model_config["llm_name"] == llm_name:
return model_config
@@ -92,19 +76,8 @@ def get_voice_config(llm_name: str) -> dict:
def get_chunker_config(chunker_strategy: str) -> dict:
"""Retrieves the configuration for a specific chunker strategy.
"""Retrieves the configuration for a specific chunker strategy."""
Enhancements:
- Supports default configs for `LLMChunker` and `HybridChunker` if not present.
- Falls back to the first available chunker config when the requested one is missing.
"""
# 1) Try to find exact match in config
chunker_list = CONFIG.get("chunker_list", [])
for chunker_config in chunker_list:
if chunker_config.get("chunker_strategy") == chunker_strategy:
return chunker_config
# 2) Provide sane defaults for newer strategies
default_configs = {
"RecursiveChunker": {
"chunker_strategy": "RecursiveChunker",
@@ -112,7 +85,6 @@ def get_chunker_config(chunker_strategy: str) -> dict:
"chunk_size": 512,
"min_characters_per_chunk": 50
},
"LLMChunker": {
"chunker_strategy": "LLMChunker",
"embedding_model": "BAAI/bge-m3",
@@ -137,127 +109,6 @@ def get_chunker_config(chunker_strategy: str) -> dict:
if chunker_strategy in default_configs:
return default_configs[chunker_strategy]
# 3) Fallback: use first available config but tag with requested strategy
if chunker_list:
fallback = chunker_list[0].copy()
fallback["chunker_strategy"] = chunker_strategy
# Non-fatal notice for visibility in logs if any
print(f"Warning: Using first available chunker config as fallback for '{chunker_strategy}'")
return fallback
# 4) If no configs available at all
raise ValueError(
f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available"
f"Chunker '{chunker_strategy}' not found "
)
#TODO: Fix this
def get_pipeline_config(
config_id: int,
db: Session | None = None,
) -> ExtractionPipelineConfig:
"""Build ExtractionPipelineConfig from database.
Args:
config_id: Database configuration ID (required). Loads pipeline
settings from the DataConfig table.
db: Optional database session. If not provided, a new session
will be created.
Returns:
ExtractionPipelineConfig with deduplication, statement extraction,
and forgetting engine settings loaded from database.
Raises:
ValueError: If config_id not found in database.
"""
from app.repositories.data_config_repository import DataConfigRepository
# Load from database
if db is None:
db_gen = get_db()
db = next(db_gen)
db_config = DataConfigRepository.get_by_id(db, config_id)
if db_config is None:
raise ValueError(f"Configuration {config_id} not found in database")
# Build DedupConfig from database
dedup_kwargs = {
"enable_llm_dedup_blockwise": bool(db_config.enable_llm_dedup_blockwise) if db_config.enable_llm_dedup_blockwise is not None else False,
"enable_llm_disambiguation": bool(db_config.enable_llm_disambiguation) if db_config.enable_llm_disambiguation is not None else False,
}
# Fuzzy thresholds
if db_config.t_name_strict is not None:
dedup_kwargs["fuzzy_name_threshold_strict"] = db_config.t_name_strict
if db_config.t_type_strict is not None:
dedup_kwargs["fuzzy_type_threshold_strict"] = db_config.t_type_strict
if db_config.t_overall is not None:
dedup_kwargs["fuzzy_overall_threshold"] = db_config.t_overall
dedup_config = DedupConfig(**dedup_kwargs)
# Build StatementExtractionConfig from database
stmt_kwargs = {}
if db_config.statement_granularity is not None:
stmt_kwargs["statement_granularity"] = db_config.statement_granularity
if db_config.include_dialogue_context is not None:
stmt_kwargs["include_dialogue_context"] = bool(db_config.include_dialogue_context)
if db_config.max_context is not None:
stmt_kwargs["max_dialogue_context_chars"] = db_config.max_context
stmt_config = StatementExtractionConfig(**stmt_kwargs)
# Build ForgettingEngineConfig from database
forget_kwargs = {}
if db_config.offset is not None:
forget_kwargs["offset"] = db_config.offset
if db_config.lambda_time is not None:
forget_kwargs["lambda_time"] = db_config.lambda_time
if db_config.lambda_mem is not None:
forget_kwargs["lambda_mem"] = db_config.lambda_mem
forget_config = ForgettingEngineConfig(**forget_kwargs)
return ExtractionPipelineConfig(
statement_extraction=stmt_config,
deduplication=dedup_config,
forgetting_engine=forget_config,
)
def get_pruning_config(
config_id: int,
db: Session | None = None,
) -> dict:
"""Retrieve semantic pruning config from database.
Args:
config_id: Database configuration ID (required).
db: Optional database session.
Returns:
Dict suitable for PruningConfig.model_validate with keys:
- pruning_switch: bool
- pruning_scene: str ("education" | "online_service" | "outbound")
- pruning_threshold: float (0-0.9)
Raises:
ValueError: If config_id not found in database.
"""
from app.repositories.data_config_repository import DataConfigRepository
if db is None:
db_gen = get_db()
db = next(db_gen)
db_config = DataConfigRepository.get_by_id(db, config_id)
if db_config is None:
raise ValueError(f"Configuration {config_id} not found in database")
return {
"pruning_switch": bool(db_config.pruning_enabled) if db_config.pruning_enabled is not None else False,
"pruning_scene": db_config.pruning_scene or "education",
"pruning_threshold": float(db_config.pruning_threshold) if db_config.pruning_threshold is not None else 0.5,
}

View File

@@ -1,269 +1,268 @@
"""
配置加载模块 - DEPRECATED
# """
# 配置加载模块 - DEPRECATED
⚠️ DEPRECATION NOTICE ⚠️
This module is deprecated and will be removed in a future version.
Global configuration variables have been eliminated in favor of dependency injection.
# ⚠️ DEPRECATION NOTICE ⚠️
# This module is deprecated and will be removed in a future version.
# Global configuration variables have been eliminated in favor of dependency injection.
Use the new MemoryConfig system instead:
- app.core.memory_config.config.MemoryConfig for configuration objects
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
# Use the new MemoryConfig system instead:
# - app.schemas.memory_config_schema.MemoryConfig for configuration objects
# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id)
阶段 1: 从 runtime.json 加载配置(路径 A- DEPRECATED
阶段 2: 从数据库加载配置(路径 B基于 dbrun.json 中的 config_id- DEPRECATED
阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
"""
import json
import os
import threading
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
# 阶段 1: 从 runtime.json 加载配置(路径 A- DEPRECATED
# 阶段 2: 从数据库加载配置(路径 B基于 dbrun.json 中的 config_id- DEPRECATED
# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
# """
# import json
# import os
# import threading
# from datetime import datetime, timedelta
# from typing import Any, Dict, Optional
#TODO: Fix this
# #TODO: Fix this
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
# try:
# from dotenv import load_dotenv
# load_dotenv()
# except Exception:
# pass
# Import unified configuration system
try:
from app.core.config import settings
USE_UNIFIED_CONFIG = True
except ImportError:
USE_UNIFIED_CONFIG = False
settings = None
# # Import unified configuration system
# try:
# from app.core.config import settings
# USE_UNIFIED_CONFIG = True
# except ImportError:
# USE_UNIFIED_CONFIG = False
# settings = None
# PROJECT_ROOT 应该指向 app/core/memory/ 目录
# __file__ = app/core/memory/utils/config/definitions.py
# os.path.dirname(__file__) = app/core/memory/utils/config
# os.path.dirname(...) = app/core/memory/utils
# os.path.dirname(...) = app/core/memory
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# # PROJECT_ROOT 应该指向 app/core/memory/ 目录
# # __file__ = app/core/memory/utils/config/definitions.py
# # os.path.dirname(__file__) = app/core/memory/utils/config
# # os.path.dirname(...) = app/core/memory/utils
# # os.path.dirname(...) = app/core/memory
# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# DEPRECATED: Global configuration lock removed
# Use MemoryConfig objects with dependency injection instead
# # DEPRECATED: Global configuration lock removed
# # Use MemoryConfig objects with dependency injection instead
# DEPRECATED: Legacy config.json loading removed
# Use MemoryConfig objects with dependency injection instead
CONFIG = {}
# # DEPRECATED: Legacy config.json loading removed
# # Use MemoryConfig objects with dependency injection instead
# CONFIG = {}
DEFAULT_VALUES = {
"llm_name": "openai/qwen-plus",
"embedding_name": "openai/nomic-embed-text:v1.5",
"chunker_strategy": "RecursiveChunker",
"group_id": "group_123",
"user_id": "default_user",
"apply_id": "default_apply",
"llm_agent_name": "openai/qwen-plus",
"llm_verify_name": "openai/qwen-plus",
"llm_image_recognition": "openai/qwen-plus",
"llm_voice_recognition": "openai/qwen-plus",
"prompt_level": "DEBUG",
"reflexion_iteration_period": "3",
"reflexion_range": "retrieval",
"reflexion_baseline": "TIME",
}
# DEFAULT_VALUES = {
# "llm_name": "openai/qwen-plus",
# "embedding_name": "openai/nomic-embed-text:v1.5",
# "chunker_strategy": "RecursiveChunker",
# "group_id": "group_123",
# "user_id": "default_user",
# "apply_id": "default_apply",
# "llm_agent_name": "openai/qwen-plus",
# "llm_verify_name": "openai/qwen-plus",
# "llm_image_recognition": "openai/qwen-plus",
# "llm_voice_recognition": "openai/qwen-plus",
# "prompt_level": "DEBUG",
# "reflexion_iteration_period": "3",
# "reflexion_range": "retrieval",
# "reflexion_baseline": "TIME",
# }
# DEPRECATED: Legacy global variables for backward compatibility only
# These will be removed in a future version
# Use MemoryConfig objects with dependency injection instead
LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
# # DEPRECATED: Legacy global variables for backward compatibility only
# # These will be removed in a future version
# # Use MemoryConfig objects with dependency injection instead
# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
# 阶段 1: 从 runtime.json 加载配置(路径 A
def _load_from_runtime_json() -> Dict[str, Any]:
"""
DEPRECATED: Legacy runtime.json loading
# # 阶段 1: 从 runtime.json 加载配置(路径 A
# def _load_from_runtime_json() -> Dict[str, Any]:
# """
# DEPRECATED: Legacy runtime.json loading
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
Returns:
Dict[str, Any]: Empty configuration (legacy support only)
"""
import warnings
warnings.warn(
"Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
return {"selections": {}}
# Returns:
# Dict[str, Any]: Empty configuration (legacy support only)
# """
# import warnings
# warnings.warn(
# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# return {"selections": {}}
# 阶段 2: 从数据库加载配置(路径 B- 已整合到统一加载器
# 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
# 保留此函数仅为向后兼容
def _load_from_database() -> Optional[Dict[str, Any]]:
"""
DEPRECATED: Legacy database configuration loading
# # 阶段 2: 从数据库加载配置(路径 B- 已整合到统一加载器
# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
# # 保留此函数仅为向后兼容
# def _load_from_database() -> Optional[Dict[str, Any]]:
# """
# DEPRECATED: Legacy database configuration loading
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
Returns:
Optional[Dict[str, Any]]: None (deprecated functionality)
"""
import warnings
warnings.warn(
"Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
return None
# Returns:
# Optional[Dict[str, Any]]: None (deprecated functionality)
# """
# import warnings
# warnings.warn(
# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# return None
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
"""
DEPRECATED: 将运行时配置暴露为全局常量供项目使用
# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
# """
# DEPRECATED: 将运行时配置暴露为全局常量供项目使用
⚠️ This function is deprecated and will be removed in a future version.
Global configuration variables have been eliminated in favor of dependency injection.
# ⚠️ This function is deprecated and will be removed in a future version.
# Global configuration variables have been eliminated in favor of dependency injection.
Use the new MemoryConfig system instead:
- app.core.memory_config.config.MemoryConfig for configuration objects
- Pass configuration objects as parameters instead of using global variables
# Use the new MemoryConfig system instead:
# - app.schemas.memory_config_schema.MemoryConfig for configuration objects
# - Pass configuration objects as parameters instead of using global variables
Args:
runtime_cfg: 运行时配置字典
"""
import warnings
warnings.warn(
"Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
# Args:
# runtime_cfg: 运行时配置字典
# """
# import warnings
# warnings.warn(
# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# Keep minimal global state for backward compatibility only
# These will be removed in a future version
global RUNTIME_CONFIG, SELECTIONS
# # Keep minimal global state for backward compatibility only
# # These will be removed in a future version
# global RUNTIME_CONFIG, SELECTIONS
RUNTIME_CONFIG = runtime_cfg
SELECTIONS = RUNTIME_CONFIG.get("selections", {})
# RUNTIME_CONFIG = runtime_cfg
# SELECTIONS = RUNTIME_CONFIG.get("selections", {})
# All other global variables have been removed
# Use MemoryConfig objects instead
# # All other global variables have been removed
# # Use MemoryConfig objects instead
# 初始化:使用统一配置加载器
def _initialize_configuration() -> None:
"""
DEPRECATED: Legacy configuration initialization
# # 初始化:使用统一配置加载器
# def _initialize_configuration() -> None:
# """
# DEPRECATED: Legacy configuration initialization
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
"""
import warnings
warnings.warn(
"Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
# Initialize with empty configuration for backward compatibility
_expose_runtime_constants({"selections": {}})
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# """
# import warnings
# warnings.warn(
# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# # Initialize with empty configuration for backward compatibility
# _expose_runtime_constants({"selections": {}})
# 模块加载时自动初始化配置
_initialize_configuration()
# # 模块加载时自动初始化配置
# _initialize_configuration()
# DEPRECATED: Global variables removed
# These variables have been eliminated in favor of dependency injection
# Use MemoryConfig objects instead of accessing global variables
# # DEPRECATED: Global variables removed
# # These variables have been eliminated in favor of dependency injection
# # Use MemoryConfig objects instead of accessing global variables
# 公共 API动态重新加载配置
def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
"""
DEPRECATED: Legacy configuration reloading
# # 公共 API动态重新加载配置
# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
# """
# DEPRECATED: Legacy configuration reloading
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
For new code, use:
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
# For new code, use:
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
Args:
config_id: Configuration ID (deprecated)
force_reload: Force reload flag (deprecated)
# Args:
# config_id: Configuration ID (deprecated)
# force_reload: Force reload flag (deprecated)
Returns:
bool: Always returns False (deprecated functionality)
"""
import logging
import warnings
# Returns:
# bool: Always returns False (deprecated functionality)
# """
# import logging
# import warnings
logger = logging.getLogger(__name__)
# logger = logging.getLogger(__name__)
warnings.warn(
"reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
# warnings.warn(
# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
"Use MemoryConfig objects with dependency injection instead.")
# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
# "Use MemoryConfig objects with dependency injection instead.")
return False
# return False
def get_current_config_id() -> Optional[str]:
"""
DEPRECATED: Legacy config ID retrieval
# def get_current_config_id() -> Optional[str]:
# """
# DEPRECATED: Legacy config ID retrieval
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
Returns:
Optional[str]: None (deprecated functionality)
"""
import warnings
warnings.warn(
"get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
return None
# Returns:
# Optional[str]: None (deprecated functionality)
# """
# import warnings
# warnings.warn(
# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# return None
def ensure_fresh_config(config_id = None) -> bool:
"""
DEPRECATED: Legacy configuration freshness check
# def ensure_fresh_config(config_id = None) -> bool:
# """
# DEPRECATED: Legacy configuration freshness check
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
For new code, use:
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
# For new code, use:
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
Args:
config_id: Configuration ID (deprecated)
# Args:
# config_id: Configuration ID (deprecated)
Returns:
bool: Always returns False (deprecated functionality)
"""
import logging
import warnings
# Returns:
# bool: Always returns False (deprecated functionality)
# """
# import logging
# import warnings
logger = logging.getLogger(__name__)
# logger = logging.getLogger(__name__)
warnings.warn(
"ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
# warnings.warn(
# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
"Use MemoryConfig objects with dependency injection instead.")
# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
# "Use MemoryConfig objects with dependency injection instead.")
return False
# return False

View File

@@ -6,8 +6,9 @@ This module provides centralized functions for creating embedder clients.
from typing import TYPE_CHECKING
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
@@ -66,7 +67,8 @@ def get_embedder_client(embedding_id: str) -> OpenAIEmbedderClient:
raise ValueError("Embedding ID is required but was not provided")
try:
embedder_config_dict = get_embedder_config(embedding_id)
with get_db_context() as db:
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_id)
except Exception as e:
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e

View File

@@ -1,9 +1,9 @@
from typing import TYPE_CHECKING
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.models.base import RedBearModelConfig
from pydantic import BaseModel
from sqlalchemy.orm import Session
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
@@ -13,105 +13,225 @@ async def handle_response(response: type[BaseModel]) -> dict:
return response.model_dump()
def get_llm_client_from_config(memory_config: "MemoryConfig") -> OpenAIClient:
class MemoryClientFactory:
"""
Get LLM client from MemoryConfig object.
Factory for creating LLM, embedder, and reranker clients.
**PREFERRED METHOD**: Use this function in production code when you have a MemoryConfig object.
This ensures proper configuration management and multi-tenant support.
Initialize once with db session, then call methods without passing db each time.
Args:
memory_config: MemoryConfig object containing llm_model_id
Returns:
OpenAIClient: Initialized LLM client
Raises:
ValueError: If LLM model ID is not configured or client initialization fails
Example:
>>> llm_client = get_llm_client_from_config(memory_config)
>>> factory = MemoryClientFactory(db)
>>> llm_client = factory.get_llm_client(model_id)
>>> embedder_client = factory.get_embedder_client(embedding_id)
"""
if not memory_config.llm_model_id:
raise ValueError(
f"Configuration {memory_config.config_id} has no LLM model configured"
)
return get_llm_client(str(memory_config.llm_model_id))
def __init__(self, db: Session):
    """Bind the factory to a database session.

    Args:
        db: Session passed to ``MemoryConfigService``; the resulting
            service is reused by every ``get_*`` method of the factory.
    """
    # Imported at call time, exactly as the original did
    # (presumably to avoid an import cycle — TODO confirm).
    from app.services.memory_config_service import MemoryConfigService as _ConfigService

    config_service = _ConfigService(db)
    self._config_service = config_service
def get_llm_client(self, llm_id: str) -> OpenAIClient:
    """Build an LLM client for the given model ID.

    Args:
        llm_id: Identifier of the LLM model; must be truthy.

    Returns:
        An ``OpenAIClient`` built from the configuration returned by the
        config service for ``llm_id``.

    Raises:
        ValueError: If ``llm_id`` is empty, if the configuration lookup
            fails, or if the client cannot be constructed. The underlying
            exception is chained as ``__cause__``.
    """
    if not llm_id:
        raise ValueError("LLM ID is required")

    # Step 1: resolve the model configuration for this ID.
    try:
        cfg = self._config_service.get_model_config(llm_id)
    except Exception as lookup_err:
        raise ValueError(f"Invalid LLM ID '{llm_id}': {str(lookup_err)}") from lookup_err

    # Step 2: build the client; on failure, name the model in the error.
    try:
        redbear_config = RedBearModelConfig(
            model_name=cfg.get("model_name"),
            provider=cfg.get("provider"),
            api_key=cfg.get("api_key"),
            base_url=cfg.get("base_url"),
        )
        return OpenAIClient(redbear_config, type_=cfg.get("type"))
    except Exception as init_err:
        failed_model = cfg.get('model_name', 'unknown')
        raise ValueError(f"Failed to initialize LLM client for model '{failed_model}': {str(init_err)}") from init_err
def get_embedder_client(self, embedding_id: str):
    """Build an embedder client for the given embedding model ID.

    Args:
        embedding_id: Identifier of the embedding model; must be truthy.

    Returns:
        An ``OpenAIEmbedderClient`` built from the configuration returned
        by the config service for ``embedding_id``.

    Raises:
        ValueError: If ``embedding_id`` is empty, if the configuration
            lookup fails, or if the client cannot be constructed. The
            underlying exception is chained as ``__cause__``.
    """
    # Imported at call time, before any validation, exactly as the original did.
    from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient

    if not embedding_id:
        raise ValueError("Embedding ID is required")

    # Step 1: resolve the embedder configuration for this ID.
    try:
        cfg = self._config_service.get_embedder_config(embedding_id)
    except Exception as lookup_err:
        raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(lookup_err)}") from lookup_err

    # Step 2: build the client; on failure, name the model in the error.
    try:
        redbear_config = RedBearModelConfig(
            model_name=cfg.get("model_name"),
            provider=cfg.get("provider"),
            api_key=cfg.get("api_key"),
            base_url=cfg.get("base_url"),
        )
        return OpenAIEmbedderClient(redbear_config)
    except Exception as init_err:
        failed_model = cfg.get('model_name', 'unknown')
        raise ValueError(f"Failed to initialize embedder client for model '{failed_model}': {str(init_err)}") from init_err
def get_reranker_client(self, rerank_id: str) -> OpenAIClient:
    """Build a reranker client for the given model ID.

    Args:
        rerank_id: Identifier of the reranker model; must be truthy.

    Returns:
        An ``OpenAIClient`` built from the configuration returned by the
        config service for ``rerank_id``.

    Raises:
        ValueError: If ``rerank_id`` is empty, if the configuration lookup
            fails, or if the client cannot be constructed. The underlying
            exception is chained as ``__cause__``.
    """
    if not rerank_id:
        raise ValueError("Rerank ID is required")

    # Step 1: resolve the model configuration for this ID.
    try:
        cfg = self._config_service.get_model_config(rerank_id)
    except Exception as lookup_err:
        raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(lookup_err)}") from lookup_err

    # Step 2: build the client; on failure, name the model in the error.
    try:
        redbear_config = RedBearModelConfig(
            model_name=cfg.get("model_name"),
            provider=cfg.get("provider"),
            api_key=cfg.get("api_key"),
            base_url=cfg.get("base_url"),
        )
        return OpenAIClient(redbear_config, type_=cfg.get("type"))
    except Exception as init_err:
        failed_model = cfg.get('model_name', 'unknown')
        raise ValueError(f"Failed to initialize reranker client for model '{failed_model}': {str(init_err)}") from init_err
def get_llm_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
    """Resolve the LLM client for a ``MemoryConfig`` object.

    Args:
        memory_config: Configuration whose ``llm_model_id`` selects the model.

    Returns:
        The result of ``self.get_llm_client`` for the stringified model ID.

    Raises:
        ValueError: If ``memory_config.llm_model_id`` is falsy.
    """
    model_id = memory_config.llm_model_id
    if model_id:
        # The ID is stringified before the lookup, as in the original.
        return self.get_llm_client(str(model_id))
    raise ValueError(
        f"Configuration {memory_config.config_id} has no LLM model configured"
    )
def get_embedder_client_from_config(self, memory_config: "MemoryConfig"):
    """Resolve the embedder client for a ``MemoryConfig`` object.

    Args:
        memory_config: Configuration whose ``embedding_model_id`` selects
            the embedding model.

    Returns:
        The result of ``self.get_embedder_client`` for the stringified
        model ID.

    Raises:
        ValueError: If ``memory_config.embedding_model_id`` is falsy.
    """
    model_id = memory_config.embedding_model_id
    if model_id:
        # The ID is stringified before the lookup, as in the original.
        return self.get_embedder_client(str(model_id))
    raise ValueError(
        f"Configuration {memory_config.config_id} has no embedding model configured"
    )
def get_reranker_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
    """Resolve the reranker client for a ``MemoryConfig`` object.

    Args:
        memory_config: Configuration whose ``rerank_model_id`` selects the
            reranker model.

    Returns:
        The result of ``self.get_reranker_client`` for the stringified
        model ID.

    Raises:
        ValueError: If ``memory_config.rerank_model_id`` is falsy.
    """
    model_id = memory_config.rerank_model_id
    if model_id:
        # The ID is stringified before the lookup, as in the original.
        return self.get_reranker_client(str(model_id))
    raise ValueError(
        f"Configuration {memory_config.config_id} has no rerank model configured"
    )
def get_llm_client(llm_id: str):
"""
Get LLM client by model ID.
# Legacy functions for backward compatibility
def get_llm_client_from_config(memory_config: "MemoryConfig", db: Session) -> OpenAIClient:
"""Get LLM client from MemoryConfig object.
**LEGACY/TEST METHOD**: Use this function only for:
- Test/evaluation code where you have a model ID directly
- Legacy code that hasn't been migrated to MemoryConfig yet
DEPRECATED: Use MemoryClientFactory(db).get_llm_client_from_config(memory_config) instead.
For production code with MemoryConfig, use get_llm_client_from_config() instead.
This function is maintained for backward compatibility during migration to the
factory pattern. New code should create a MemoryClientFactory instance and use
its get_llm_client_from_config method directly.
Args:
llm_id: LLM model ID (required)
memory_config: Configuration containing llm_model_id
db: Database session
Returns:
OpenAIClient: Initialized LLM client
OpenAIClient configured for the LLM model
Raises:
ValueError: If llm_id is not provided or client initialization fails
Example:
>>> # For tests/evaluations only
>>> llm_client = get_llm_client("model-uuid-string")
ValueError: If memory_config has no LLM model configured
"""
if not llm_id:
raise ValueError("LLM ID is required but was not provided")
try:
model_config = get_model_config(llm_id)
except Exception as e:
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
try:
llm_client = OpenAIClient(RedBearModelConfig(
model_name=model_config.get("model_name"),
provider=model_config.get("provider"),
api_key=model_config.get("api_key"),
base_url=model_config.get("base_url")
),type_=model_config.get("type"))
return llm_client
except Exception as e:
model_name = model_config.get('model_name', 'unknown')
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
return MemoryClientFactory(db).get_llm_client_from_config(memory_config)
def get_reranker_client(rerank_id: str):
"""
Get an LLM client configured for reranking.
def get_llm_client(llm_id: str, db: Session) -> OpenAIClient:
"""Get LLM client by model ID.
DEPRECATED: Use MemoryClientFactory(db).get_llm_client(llm_id) instead.
This function is maintained for backward compatibility during migration to the
factory pattern. New code should create a MemoryClientFactory instance and use
its get_llm_client method directly.
Args:
rerank_id: Reranker model ID (required)
llm_id: LLM model ID
db: Database session
Returns:
OpenAIClient: Initialized client for the reranker model
Raises:
ValueError: If rerank_id is not provided or client initialization fails
OpenAIClient configured for the LLM model
"""
if not rerank_id:
raise ValueError("Rerank ID is required but was not provided")
return MemoryClientFactory(db).get_llm_client(llm_id)
def get_embedder_client(embedding_id: str, db: Session):
"""Get embedder client by model ID.
try:
model_config = get_model_config(rerank_id)
except Exception as e:
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
DEPRECATED: Use MemoryClientFactory(db).get_embedder_client(embedding_id) instead.
try:
reranker_client = OpenAIClient(RedBearModelConfig(
model_name=model_config.get("model_name"),
provider=model_config.get("provider"),
api_key=model_config.get("api_key"),
base_url=model_config.get("base_url")
),type_=model_config.get("type"))
return reranker_client
except Exception as e:
model_name = model_config.get('model_name', 'unknown')
raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e
This function is maintained for backward compatibility during migration to the
factory pattern. New code should create a MemoryClientFactory instance and use
its get_embedder_client method directly.
Args:
embedding_id: Embedding model ID
db: Database session
Returns:
OpenAIEmbedderClient configured for the embedding model
"""
return MemoryClientFactory(db).get_embedder_client(embedding_id)
def get_reranker_client(rerank_id: str, db: Session) -> OpenAIClient:
    """Return a reranker client for ``rerank_id`` (legacy wrapper).

    DEPRECATED: call ``MemoryClientFactory(db).get_reranker_client(rerank_id)``
    directly. This module-level function only exists so callers that have
    not yet moved to the factory pattern keep working.

    Args:
        rerank_id: Reranker model ID.
        db: Database session used to construct the factory.

    Returns:
        Whatever ``MemoryClientFactory.get_reranker_client`` returns for
        ``rerank_id``.
    """
    factory = MemoryClientFactory(db)
    return factory.get_reranker_client(rerank_id)

View File

@@ -6,11 +6,12 @@
"""
import logging
from typing import List, Any
import time
from typing import Any, List
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.db import get_db_context
from app.schemas.memory_storage_schema import ConflictResultSchema
from pydantic import BaseModel
@@ -25,7 +26,9 @@ async def conflict(evaluate_data: List[Any]) -> List[Any]:
冲突记忆列表JSON 数组)。
"""
from app.core.memory.utils.config import definitions as config_defs
client = get_llm_client(config_defs.SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
messages = [{"role": "user", "content": rendered_prompt}]
print(f"提示词长度: {len(rendered_prompt)}")

View File

@@ -6,11 +6,12 @@
"""
import logging
from typing import List, Any
import time
from typing import Any, List
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.db import get_db_context
from app.schemas.memory_storage_schema import ReflexionResultSchema
from pydantic import BaseModel
@@ -25,7 +26,9 @@ async def reflexion(ref_data: List[Any]) -> List[Any]:
反思结果列表JSON 数组)。
"""
from app.core.memory.utils.config import definitions as config_defs
client = get_llm_client(config_defs.SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema)
messages = [{"role": "user", "content": rendered_prompt}]
print(f"提示词长度: {len(rendered_prompt)}")

View File

@@ -5,16 +5,24 @@ This module provides functionality to analyze chunk content and generate insight
"""
import asyncio
from typing import List, Dict, Any
from collections import Counter
from typing import Any, Dict, List
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.logging_config import get_business_logger
business_logger = get_business_logger()
def _get_llm_client():
"""Get LLM client using db context."""
with get_db_context() as db:
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
class ChunkInsight(BaseModel):
"""Pydantic model for chunk insight."""
insight: str = Field(..., description="对chunk内容的深度洞察分析")
@@ -40,7 +48,7 @@ async def classify_chunk_domain(chunk: str) -> str:
Domain name
"""
try:
llm_client = get_llm_client()
llm_client = _get_llm_client()
prompt = f"""请将以下文本内容归类到最合适的领域中。
@@ -177,7 +185,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str
]
# 调用LLM生成洞察
llm_client = get_llm_client()
llm_client = _get_llm_client()
response = await llm_client.chat(messages=messages)
insight = response.content.strip()

View File

@@ -5,15 +5,23 @@ This module provides functionality to summarize chunk content using LLM.
"""
import asyncio
from typing import List, Dict, Any
from typing import Any, Dict, List
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.logging_config import get_business_logger
business_logger = get_business_logger()
def _get_llm_client():
"""Get LLM client using db context."""
with get_db_context() as db:
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
class ChunkSummary(BaseModel):
"""Pydantic model for chunk summary."""
summary: str = Field(..., description="简洁的chunk内容摘要")
@@ -59,7 +67,7 @@ async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str
]
# 调用LLM生成摘要
llm_client = get_llm_client()
llm_client = _get_llm_client()
response = await llm_client.chat(messages=messages)
summary = response.content.strip()

View File

@@ -7,14 +7,22 @@ This module provides functionality to extract meaningful tags from chunk content
import asyncio
from collections import Counter
from typing import List, Tuple
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.logging_config import get_business_logger
business_logger = get_business_logger()
def _get_llm_client():
"""Get LLM client using db context."""
with get_db_context() as db:
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
class ExtractedTags(BaseModel):
"""Pydantic model for extracted tags."""
tags: List[str] = Field(..., description="从文本中提取的关键标签列表")
@@ -56,7 +64,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
)
llm_client = get_llm_client()
llm_client = _get_llm_client()
# 为每个chunk单独提取标签然后统计频率
all_tags = []
@@ -151,7 +159,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
]
# 调用LLM提取人物形象
llm_client = get_llm_client()
llm_client = _get_llm_client()
structured_response = await llm_client.response_structured(
messages=messages,
response_model=ExtractedPersona

View File

@@ -391,6 +391,29 @@ class MemoryConfig:
embedding_params: Dict[str, Any] = field(default_factory=dict)
config_version: str = "2.0"
# Pipeline config: Deduplication
enable_llm_dedup_blockwise: bool = False
enable_llm_disambiguation: bool = False
deep_retrieval: bool = True
t_type_strict: float = 0.8
t_name_strict: float = 0.8
t_overall: float = 0.8
# Pipeline config: Statement extraction
statement_granularity: int = 2
include_dialogue_context: bool = False
max_dialogue_context_chars: int = 1000
# Pipeline config: Forgetting engine
lambda_time: float = 0.5
lambda_mem: float = 0.5
offset: float = 0.0
# Pipeline config: Pruning
pruning_enabled: bool = False
pruning_scene: Optional[str] = "education"
pruning_threshold: float = 0.5
def __post_init__(self):
"""Validate configuration after initialization."""
if not self.config_name or not self.config_name.strip():

View File

@@ -3,26 +3,26 @@
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
"""
import time
import uuid
import json
import asyncio
import datetime
from typing import Dict, Any, Optional, List, AsyncGenerator
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy import select
import json
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.models import AgentConfig, ModelConfig, ModelApiKey
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.core.rag.nlp.search import knowledge_retrieval
from app.services.langchain_tool_server import Search
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -83,17 +83,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"""
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
group_id=end_user_id,
message=question,
history=[],
search_switch="1",
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
from app.db import get_db
db = next(get_db())
try:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
group_id=end_user_id,
message=question,
history=[],
search_switch="1",
config_id=config_id,
db=db,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
)
)
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
@@ -713,9 +719,9 @@ class DraftRunService:
Raises:
BusinessException: 当指定的会话不存在时
"""
from app.services.conversation_service import ConversationService
from app.schemas.conversation_schema import ConversationCreate
from app.models import Conversation as ConversationModel
from app.schemas.conversation_schema import ConversationCreate
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db)

View File

@@ -7,14 +7,15 @@ Classes:
EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能
"""
from typing import Dict, Any, Optional, List
import statistics
import json
from pydantic import BaseModel, Field
import statistics
from typing import Any, Dict, List, Optional
from app.core.logging_config import get_business_logger
from app.repositories.neo4j.emotion_repository import EmotionRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.logging_config import get_business_logger
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = get_business_logger()
@@ -454,7 +455,7 @@ class EmotionAnalyticsService:
async def generate_emotion_suggestions(
self,
end_user_id: str,
config_id: Optional[int] = None
db: Session,
) -> Dict[str, Any]:
"""生成个性化情绪建议
@@ -462,7 +463,7 @@ class EmotionAnalyticsService:
Args:
end_user_id: 宿主ID用户组ID
config_id: 配置ID可选用于从数据库加载LLM配置
db: 数据库会话
Returns:
Dict: 包含个性化建议的响应:
@@ -470,14 +471,32 @@ class EmotionAnalyticsService:
- suggestions: 建议列表3-5条
"""
try:
logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}")
logger.info(f"生成个性化情绪建议: user={end_user_id}")
# 1. 如果提供了 config_id从数据库加载配置
if config_id is not None:
from app.core.memory.utils.config.definitions import reload_configuration_from_database
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置")
# 1. 从 end_user_id 获取关联的 memory_config_id
llm_client = None
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is not None:
from app.services.memory_config_service import (
MemoryConfigService,
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=int(config_id),
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
)
from app.core.memory.client_factory import MemoryClientFactory
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
except Exception as e:
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
# 2. 获取情绪健康数据
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
@@ -498,8 +517,9 @@ class EmotionAnalyticsService:
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
# 7. 调用LLM生成建议使用配置中的LLM
from app.core.memory.utils.llm.llm_utils import get_llm_client
llm_client = get_llm_client()
if llm_client is None:
# 无法获取配置时,抛出错误而不是使用默认配置
raise ValueError("无法获取LLM配置请确保end_user关联了有效的memory_config")
# 将 prompt 转换为 messages 格式
messages = [
@@ -598,7 +618,9 @@ class EmotionAnalyticsService:
Returns:
str: LLM prompt
"""
from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt
from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_suggestions_prompt,
)
prompt = await render_emotion_suggestions_prompt(
health_data=health_data,

View File

@@ -9,10 +9,12 @@ Classes:
import logging
from typing import Optional
from app.core.memory.models.emotion_models import EmotionExtraction
from app.models.data_config_model import DataConfig
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.core.memory.models.emotion_models import EmotionExtraction
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.data_config_model import DataConfig
logger = logging.getLogger(__name__)
@@ -50,7 +52,9 @@ class EmotionExtractionService:
"""
if self.llm_client is None or model_id:
effective_model_id = model_id or self.llm_id
self.llm_client = get_llm_client(effective_model_id)
with get_db_context() as db:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(effective_model_id)
return self.llm_client
async def extract_emotion(
@@ -142,7 +146,9 @@ class EmotionExtractionService:
Returns:
Formatted prompt string for LLM
"""
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_extraction_prompt,
)
prompt = await render_emotion_extraction_prompt(
statement=statement,

View File

@@ -21,8 +21,8 @@ from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.db import get_db
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
@@ -45,8 +45,7 @@ config_logger = get_config_logger()
# Initialize Neo4j connector for analytics functions
_neo4j_connector = Neo4jConnector()
db_gen = get_db()
db = next(db_gen)
class MemoryAgentService:
"""Service for memory agent operations"""
@@ -55,27 +54,6 @@ class MemoryAgentService:
self.user_locks: Dict[str, Lock] = {}
self.locks_lock = Lock()
def load_memory_config(self, config_id: int) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
This method delegates to the centralized MemoryConfigService to avoid
code duplication with other services.
Args:
config_id: Configuration ID from database
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If validation fails
"""
return MemoryConfigService.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
countext = re.findall(r'"status": "(.*?)",', messages)[0]
@@ -277,14 +255,17 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str:
async def write_memory(self, group_id: str, message: str, config_id: str, db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
Args:
group_id: Group identifier
group_id: Group identifier (also used as end_user_id)
message: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns:
Write operation result status
@@ -292,14 +273,24 @@ class MemoryAgentService:
Raises:
ValueError: If config loading fails or write operation fails
"""
if config_id==None:
config_id = os.getenv("config_id")
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
import time
start_time = time.time()
# Load configuration from database only
try:
memory_config = self.load_memory_config(config_id)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
@@ -366,6 +357,7 @@ class MemoryAgentService:
history: List[Dict],
search_switch: str,
config_id: str,
db: Session,
storage_type: str,
user_rag_memory_id: str
) -> Dict:
@@ -378,11 +370,14 @@ class MemoryAgentService:
- "2": Direct answer based on context
Args:
group_id: Group identifier
group_id: Group identifier (also used as end_user_id)
message: User message
history: Conversation history
search_switch: Search mode switch
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns:
Dict with 'answer' and 'intermediate_outputs' keys
@@ -394,8 +389,13 @@ class MemoryAgentService:
import time
start_time = time.time()
if config_id==None:
config_id = os.getenv("config_id")
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
@@ -411,7 +411,11 @@ class MemoryAgentService:
with group_lock:
# Step 1: Load configuration from database only
try:
memory_config = self.load_memory_config(config_id)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
@@ -696,7 +700,11 @@ class MemoryAgentService:
logger.info("Classifying message type")
# Load configuration to get LLM model ID
memory_config = self.load_memory_config(config_id)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
@@ -865,7 +873,8 @@ class MemoryAgentService:
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None
llm_id: Optional[str] = None,
db: Session = None
) -> Dict[str, Any]:
"""
获取用户详情,包含:
@@ -877,6 +886,7 @@ class MemoryAgentService:
- end_user_id: 用户ID可选
- current_user_id: 当前登录用户的ID保留参数
- llm_id: LLM模型ID用于生成标签可选如果不提供则跳过标签生成
- db: 数据库会话(可选)
返回格式:
{
@@ -893,7 +903,7 @@ class MemoryAgentService:
# 1. 根据 end_user_id 获取 end_user_name
try:
if end_user_id:
if end_user_id and db:
from app.repositories import end_user_repository
from app.schemas.end_user_schema import EndUser as EndUserSchema
@@ -948,7 +958,9 @@ class MemoryAgentService:
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
# 使用LLM提取标签
llm_client = get_llm_client(llm_id)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_id)
# 定义标签提取的结构
class UserTags(BaseModel):
@@ -1110,4 +1122,69 @@ class MemoryAgentService:
# "msg": "解析失败",
# "error_code": "DOC_PARSE_ERROR",
# "data": {"error": str(e)}
# }
# }
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
    """Look up the memory configuration linked to an end user.

    Resolution steps:
      1. Load the end user row to obtain its ``app_id``.
      2. Pick the active release of that app with the highest version.
      3. Read ``memory.memory_content`` from the release config; that value
         is reported as ``memory_config_id``.

    Args:
        end_user_id: Identifier of the end user.
        db: Database session used for both queries.

    Returns:
        A dict with the keys ``end_user_id``, ``app_id``, ``release_id``,
        ``release_version`` and ``memory_config_id``. ``memory_config_id`` is
        ``None`` when the release config has no dict-valued ``memory`` entry
        or that entry has no ``memory_content`` key.

    Raises:
        ValueError: If the end user does not exist or the app has no active
            release.
    """
    from app.models.app_release_model import AppRelease
    from app.models.end_user_model import EndUser
    from sqlalchemy import select

    logger.info(f"Getting connected config for end_user: {end_user_id}")

    # Step 1: resolve the end user and the app it belongs to.
    user_row = db.query(EndUser).filter(EndUser.id == end_user_id).first()
    if not user_row:
        logger.warning(f"End user not found: {end_user_id}")
        raise ValueError(f"终端用户不存在: {end_user_id}")

    app_id = user_row.app_id
    logger.debug(f"Found end_user app_id: {app_id}")

    # Step 2: newest active release of that app (highest version first).
    release_query = select(AppRelease).where(
        AppRelease.app_id == app_id,
        AppRelease.is_active.is_(True),
    )
    release_query = release_query.order_by(AppRelease.version.desc())
    latest_release = db.scalars(release_query).first()
    if not latest_release:
        logger.warning(f"No active release found for app: {app_id}")
        raise ValueError(f"应用未发布: {app_id}")

    logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")

    # Step 3: pull the memory config id out of the release config.
    release_config = latest_release.config or {}
    memory_section = release_config.get('memory', {})
    if isinstance(memory_section, dict):
        memory_config_id = memory_section.get('memory_content')
    else:
        # A non-dict "memory" entry carries no usable id.
        memory_config_id = None

    connected = {
        "end_user_id": str(end_user_id),
        "app_id": str(app_id),
        "release_id": str(latest_release.id),
        "release_version": latest_release.version,
        "memory_config_id": memory_config_id,
    }

    logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
    return connected

View File

@@ -3,7 +3,6 @@ Memory Configuration Service
Centralized configuration loading and management for memory services.
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
Database session management is handled internally.
"""
import time
@@ -57,7 +56,7 @@ def _validate_config_id(config_id):
invalid_value=config_id,
)
return parsed_id
except ValueError as e:
except ValueError:
raise InvalidConfigError(
f"Invalid configuration ID format: '{config_id}'",
field_name="config_id",
@@ -77,19 +76,29 @@ class MemoryConfigService:
This class provides a single implementation of configuration loading logic
that can be shared across multiple services, eliminating code duplication.
Database session management is handled internally.
Usage:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
model_config = config_service.get_model_config(model_id)
"""
@staticmethod
def __init__(self, db: Session):
"""Initialize the service with a database session.
Args:
db: SQLAlchemy database session
"""
self.db = db
def load_memory_config(
self,
config_id: int,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
This method manages its own database session internally.
Args:
config_id: Configuration ID from database
service_name: Name of the calling service (for logging purposes)
@@ -100,27 +109,6 @@ class MemoryConfigService:
Raises:
ConfigurationError: If validation fails
"""
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
try:
return MemoryConfigService._load_memory_config_with_db(
config_id=config_id,
db=db,
service_name=service_name,
)
finally:
db.close()
@staticmethod
def _load_memory_config_with_db(
config_id: int,
db: Session,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""Internal method that loads memory configuration with an existing db session."""
start_time = time.time()
config_logger.info(
@@ -137,7 +125,7 @@ class MemoryConfigService:
try:
validated_config_id = _validate_config_id(config_id)
result = DataConfigRepository.get_config_with_workspace(db, validated_config_id)
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
if not result:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
@@ -160,7 +148,7 @@ class MemoryConfigService:
embedding_uuid = validate_embedding_model(
validated_config_id,
memory_config.embedding_id,
db,
self.db,
workspace.tenant_id,
workspace.id,
)
@@ -169,7 +157,7 @@ class MemoryConfigService:
llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id,
"llm",
db,
self.db,
workspace.tenant_id,
required=True,
config_id=validated_config_id,
@@ -183,7 +171,7 @@ class MemoryConfigService:
rerank_uuid, rerank_name = validate_and_resolve_model_id(
memory_config.rerank_id,
"rerank",
db,
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
@@ -194,7 +182,7 @@ class MemoryConfigService:
embedding_name, _ = validate_model_exists_and_active(
embedding_uuid,
"embedding",
db,
self.db,
workspace.tenant_id,
config_id=validated_config_id,
workspace_id=workspace.id,
@@ -220,6 +208,25 @@ class MemoryConfigService:
reflexion_range=memory_config.reflexion_range or "retrieval",
reflexion_baseline=memory_config.baseline or "time",
loaded_at=datetime.now(),
# Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
)
elapsed_ms = (time.time() - start_time) * 1000
@@ -262,3 +269,131 @@ class MemoryConfigService:
raise
else:
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
def get_model_config(self, model_id: str) -> dict:
    """Return the connection settings for an LLM model.

    Looks the model up through ``ModelConfigService.get_model_by_id`` using
    this service's session, then flattens its first API-key record together
    with the global LLM timeout/retry settings.

    Args:
        model_id: ID of the model to look up.

    Returns:
        Dict with the keys ``model_name``, ``provider``, ``api_key``,
        ``base_url``, ``model_config_id``, ``type``, ``timeout`` and
        ``max_retries``.

    Raises:
        HTTPException: 404 when the model does not exist, or when it exists
            but has no API key attached.
    """
    from app.core.config import settings
    from app.models.models_model import ModelApiKey
    from app.services.model_service import ModelConfigService as ModelSvc
    from fastapi import status
    from fastapi.exceptions import HTTPException

    config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
    if not config:
        logger.warning(f"Model ID {model_id} not found")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")

    # Without this guard an empty/None api_keys collection would surface as a
    # bare IndexError/TypeError from the [0] lookup below.
    if not config.api_keys:
        logger.warning(f"Model ID {model_id} has no API key configured")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型未配置API Key")

    # Only the first API-key record is used.
    api_config: ModelApiKey = config.api_keys[0]

    return {
        "model_name": api_config.model_name,
        "provider": api_config.provider,
        "api_key": api_config.api_key,
        "base_url": api_config.api_base,
        "model_config_id": api_config.model_config_id,
        "type": config.type,
        "timeout": settings.LLM_TIMEOUT,
        "max_retries": settings.LLM_MAX_RETRIES,
    }
def get_embedder_config(self, embedding_id: str) -> dict:
    """Return the connection settings for an embedding model.

    Looks the model up through ``ModelConfigService.get_model_by_id`` using
    this service's session, then flattens its first API-key record together
    with fixed timeout/retry values.

    Args:
        embedding_id: ID of the embedding model to look up.

    Returns:
        Dict with the keys ``model_name``, ``provider``, ``api_key``,
        ``base_url``, ``model_config_id``, ``type``, ``timeout`` (120.0) and
        ``max_retries`` (5).

    Raises:
        HTTPException: 404 when the model does not exist, or when it exists
            but has no API key attached.
    """
    from app.models.models_model import ModelApiKey
    from app.services.model_service import ModelConfigService as ModelSvc
    from fastapi import status
    from fastapi.exceptions import HTTPException

    config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
    if not config:
        logger.warning(f"Embedding model ID {embedding_id} not found")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")

    # Without this guard an empty/None api_keys collection would surface as a
    # bare IndexError/TypeError from the [0] lookup below.
    if not config.api_keys:
        logger.warning(f"Embedding model ID {embedding_id} has no API key configured")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型未配置API Key")

    # Only the first API-key record is used.
    api_config: ModelApiKey = config.api_keys[0]

    return {
        "model_name": api_config.model_name,
        "provider": api_config.provider,
        "api_key": api_config.api_key,
        "base_url": api_config.api_base,
        "model_config_id": api_config.model_config_id,
        "type": config.type,
        "timeout": 120.0,
        "max_retries": 5,
    }
@staticmethod
def get_pipeline_config(memory_config: MemoryConfig):
    """Assemble an ExtractionPipelineConfig from a loaded MemoryConfig.

    Args:
        memory_config: MemoryConfig whose deduplication, statement-extraction
            and forgetting-engine fields are copied into the pipeline config.

    Returns:
        ExtractionPipelineConfig composed of a DedupConfig, a
        StatementExtractionConfig and a ForgettingEngineConfig.
    """
    from app.core.memory.models.variate_config import (
        DedupConfig,
        ExtractionPipelineConfig,
        ForgettingEngineConfig,
        StatementExtractionConfig,
    )

    # Sub-configs are built in the order dedup -> statement -> forgetting;
    # keyword order below only controls evaluation order, not field mapping.
    return ExtractionPipelineConfig(
        deduplication=DedupConfig(
            enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
            enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
            fuzzy_name_threshold_strict=memory_config.t_name_strict,
            fuzzy_type_threshold_strict=memory_config.t_type_strict,
            fuzzy_overall_threshold=memory_config.t_overall,
        ),
        statement_extraction=StatementExtractionConfig(
            statement_granularity=memory_config.statement_granularity,
            include_dialogue_context=memory_config.include_dialogue_context,
            max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
        ),
        forgetting_engine=ForgettingEngineConfig(
            offset=memory_config.offset,
            lambda_time=memory_config.lambda_time,
            lambda_mem=memory_config.lambda_mem,
        ),
    )
@staticmethod
def get_pruning_config(memory_config: MemoryConfig) -> dict:
    """Extract the semantic-pruning settings from a MemoryConfig.

    Args:
        memory_config: MemoryConfig carrying the pruning fields.

    Returns:
        Dict with the keys ``pruning_switch`` (taken from
        ``pruning_enabled``), ``pruning_scene`` and ``pruning_threshold``.
        NOTE(review): the previous docstring describes this as input for
        ``PruningConfig.model_validate`` -- confirm against the caller.
    """
    switch = memory_config.pruning_enabled
    scene = memory_config.pruning_scene
    threshold = memory_config.pruning_threshold
    return dict(
        pruning_switch=switch,
        pruning_scene=scene,
        pruning_threshold=threshold,
    )

View File

@@ -49,27 +49,6 @@ class MemoryStorageService:
def __init__(self):
logger.info("MemoryStorageService initialized")
def load_memory_config(self, config_id: int, db: Session) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
This method delegates to the centralized MemoryConfigService to avoid
code duplication with other services.
Args:
config_id: Configuration ID from database
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If validation fails
"""
return MemoryConfigService.load_memory_config(
config_id=config_id,
service_name="MemoryStorageService"
)
async def get_storage_info(self) -> dict:
"""
@@ -293,7 +272,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# Load configuration from database only using centralized manager
try:
memory_config = MemoryConfigService.load_memory_config(
config_service = MemoryConfigService(self.db)
memory_config = config_service.load_memory_config(
config_id=int(cid),
service_name="MemoryStorageService.pilot_run_stream"
)
@@ -320,13 +300,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
async def run_pipeline():
"""在后台执行管线并捕获异常"""
try:
from app.core.memory.main import main as pipeline_main
from app.services.pilot_run_service import run_pilot_extraction
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
await pipeline_main(
dialogue_text=dialogue_text,
is_pilot_run=True,
progress_callback=progress_callback
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
await run_pilot_extraction(
memory_config=memory_config,
dialogue_text=dialogue_text,
db=self.db,
progress_callback=progress_callback,
)
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")

View File

@@ -0,0 +1,219 @@
"""
Pilot Run Service - 试运行服务
用于执行记忆系统的试运行流程,不保存到 Neo4j。
"""
import os
import re
import time
from datetime import datetime
from typing import Awaitable, Callable, Optional
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.utils.config.config_utils import (
get_pipeline_config,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from sqlalchemy.orm import Session
logger = get_memory_logger(__name__)
def _append_run_banner(log_file: str, text: str) -> None:
    """Append one raw banner string to the timing log file."""
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(text)


def _parse_dialogue_messages(dialogue_text: str) -> list:
    """Split raw dialogue text into ConversationMessage objects.

    Speaker-prefixed turns ("用户:" / "AI:") become one message each; when no
    turn matches, the whole stripped text becomes a single "用户" message.
    """
    pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
    matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
    messages = [
        ConversationMessage(role=r, msg=c.strip())
        for r, c in matches
        if c.strip()
    ]
    if not messages:
        messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
    return messages


async def run_pilot_extraction(
    memory_config: MemoryConfig,
    dialogue_text: str,
    db: Session,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
) -> None:
    """Run the knowledge-extraction pipeline in pilot (dry-run) mode.

    The dialogue text is parsed, chunked, run through the extraction
    orchestrator with ``is_pilot_run=True`` and summarised. Nothing is
    written to Neo4j by this function. Step timings are appended to
    ``logs/time.log``.

    Args:
        memory_config: Memory configuration loaded from the database.
        dialogue_text: Raw dialogue text to process.
        db: Database session handed to ``MemoryClientFactory`` to build the
            LLM and embedder clients.
        progress_callback: Optional async callback invoked as
            ``(stage, message, data)`` where ``stage`` is a stage identifier,
            ``message`` is human-readable progress text and ``data`` is an
            optional dict of extra details (``None`` when there are none).

    Raises:
        Exception: Any failure in client setup, parsing/chunking or
            extraction is logged and re-raised. A failure in the memory
            summary step is logged and suppressed.
    """
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    started_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    _append_run_banner(log_file, f"\n=== Pilot Run Started: {started_at} ===\n")

    pipeline_start = time.time()
    neo4j_connector = None

    try:
        # Step 1: build the LLM / embedder clients and the Neo4j connector.
        logger.info("Initializing clients...")
        step_start = time.time()

        client_factory = MemoryClientFactory(db)
        llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
        embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
        neo4j_connector = Neo4jConnector()

        log_time("Client Initialization", time.time() - step_start, log_file)

        # Step 2: parse the dialogue text and chunk it.
        logger.info("Parsing dialogue text...")
        step_start = time.time()

        messages = _parse_dialogue_messages(dialogue_text)
        context = ConversationContext(msgs=messages)
        dialog = DialogData(
            context=context,
            ref_id="pilot_dialog_1",
            group_id=str(memory_config.workspace_id),
            user_id=str(memory_config.tenant_id),
            apply_id=str(memory_config.config_id),
            metadata={"source": "pilot_run", "input_type": "frontend_text"},
        )

        if progress_callback:
            # Pass data=None explicitly so the call matches the declared
            # three-argument callback signature.
            await progress_callback("text_preprocessing", "开始预处理文本...", None)

        chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
            data=[dialog],
            chunker_strategy=memory_config.chunker_strategy,
            llm_client=llm_client,
        )
        logger.info(f"Processed dialogue text: {len(messages)} messages")

        # Report every chunk (content preview capped at 200 chars), then a summary.
        if progress_callback:
            for dlg in chunked_dialogs:
                for i, chunk in enumerate(dlg.chunks):
                    chunk_result = {
                        "chunk_index": i + 1,
                        "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
                        "full_length": len(chunk.content),
                        "dialog_id": dlg.id,
                        "chunker_strategy": memory_config.chunker_strategy,
                    }
                    await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)

            preprocessing_summary = {
                "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
                "total_dialogs": len(chunked_dialogs),
                "chunker_strategy": memory_config.chunker_strategy,
            }
            await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)

        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 3: set up the extraction orchestrator.
        logger.info("Initializing extraction orchestrator...")
        step_start = time.time()

        config = get_pipeline_config(memory_config)
        logger.info(
            f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
            f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
        )

        orchestrator = ExtractionOrchestrator(
            llm_client=llm_client,
            embedder_client=embedder_client,
            connector=neo4j_connector,
            config=config,
            progress_callback=progress_callback,
            embedding_id=str(memory_config.embedding_model_id),
        )

        log_time("Orchestrator Initialization", time.time() - step_start, log_file)

        # Step 4: run the extraction pipeline.
        logger.info("Running extraction pipeline...")
        step_start = time.time()

        if progress_callback:
            await progress_callback("knowledge_extraction", "正在知识抽取...", None)

        extraction_result = await orchestrator.run(
            dialog_data_list=chunked_dialogs,
            is_pilot_run=True,
        )

        # Unpacking fails fast if the orchestrator stops returning a 7-tuple;
        # the pieces themselves are not stored by this function.
        (
            dialogue_nodes,
            chunk_nodes,
            statement_nodes,
            entity_nodes,
            statement_chunk_edges,
            statement_entity_edges,
            entity_edges,
        ) = extraction_result

        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        if progress_callback:
            await progress_callback("generating_results", "正在生成结果...", None)

        # Step 5: generate memory summaries (best effort: failures are logged only).
        try:
            logger.info("Generating memory summaries...")
            step_start = time.time()

            from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
                memory_summary_generation,
            )

            await memory_summary_generation(
                chunked_dialogs,
                llm_client=llm_client,
                embedder_client=embedder_client,
            )

            log_time("Memory Summary Generation", time.time() - step_start, log_file)
        except Exception as e:
            logger.error(f"Memory summary step failed: {e}", exc_info=True)

        logger.info("Pilot run completed: Skipping Neo4j save")

    except Exception as e:
        logger.error(f"Pilot run failed: {e}", exc_info=True)
        raise

    finally:
        if neo4j_connector:
            try:
                await neo4j_connector.close()
            except Exception as close_err:
                # Closing stays best-effort, but the failure is recorded
                # instead of vanishing silently.
                logger.warning(f"Failed to close Neo4j connector: {close_err}")

    total_time = time.time() - pipeline_start
    log_time("TOTAL PILOT RUN TIME", total_time, log_file)

    finished_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    _append_run_banner(log_file, f"=== Pilot Run Completed: {finished_at} ===\n\n")

    logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")

View File

@@ -176,7 +176,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
"""Celery task to process a read message via MemoryAgentService.
Args:
group_id: Group ID for the memory agent
group_id: Group ID for the memory agent (also used as end_user_id)
message: User message to process
history: Conversation history
search_switch: Search switch parameter
@@ -190,9 +190,28 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
"""
start_time = time.time()
# Resolve config_id if None
actual_config_id = config_id
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
connected_config = get_end_user_connected_config(group_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception as e:
# Nothing is logged here: the lookup error is swallowed and config_id stays None - will fail later with proper error
pass
async def _run() -> str:
service = MemoryAgentService()
return await service.read_memory(group_id, message, history, search_switch, config_id,storage_type,user_rag_memory_id)
db = next(get_db())
try:
service = MemoryAgentService()
return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
finally:
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
@@ -246,7 +265,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
"""Celery task to process a write message via MemoryAgentService.
Args:
group_id: Group ID for the memory agent
group_id: Group ID for the memory agent (also used as end_user_id)
message: Message to write
config_id: Optional configuration ID
@@ -258,9 +277,28 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
"""
start_time = time.time()
# Resolve config_id if None
actual_config_id = config_id
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
connected_config = get_end_user_connected_config(group_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception as e:
# Nothing is logged here: the lookup error is swallowed and config_id stays None - will fail later with proper error
pass
async def _run() -> str:
service = MemoryAgentService()
return await service.write_memory(group_id, message, config_id,storage_type,user_rag_memory_id)
db = next(get_db())
try:
service = MemoryAgentService()
return await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
finally:
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突