refactor(memory): restructure memory system and improve configuration management
- Remove deprecated main.py entry point from memory module - Reorganize imports across controllers and services for consistency - Update emotion controller to pass db session instead of config_id to services - Enhance memory agent controller with db session parameter for status_type and user_profile endpoints - Refactor memory agent service to accept db parameter in classify_message_type method - Improve configuration handling in celery_app by removing automatic database reload - Update all memory-related services to use centralized config management - Standardize import ordering and remove unused imports across 50+ files - Add pilot_run_service for new pilot execution workflow - Refactor extraction engine, reflection engine, and search services for better modularity - Update LLM utilities and embedder configuration for improved flexibility - Enhance type classifier and verification tools with better error handling - Improve memory evaluation modules (LOCOMO, LongMemEval, MemSciQA) with consistent patterns
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
from celery import Celery
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
@@ -13,7 +13,6 @@ celery_app = Celery(
|
||||
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||
)
|
||||
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
|
||||
|
||||
# 配置使用本地队列,避免与远程 worker 冲突
|
||||
celery_app.conf.task_default_queue = 'localhost_test_wyl'
|
||||
@@ -22,6 +21,7 @@ celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
|
||||
|
||||
# macOS 兼容性配置
|
||||
import platform
|
||||
|
||||
if platform.system() == 'Darwin': # macOS
|
||||
# 设置环境变量解决 fork 问题
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
@@ -10,22 +10,21 @@ Routes:
|
||||
POST /emotion/suggestions - 获取个性化情绪建议
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.emotion_schema import (
|
||||
EmotionHealthRequest,
|
||||
EmotionSuggestionsRequest,
|
||||
EmotionTagsRequest,
|
||||
EmotionWordcloudRequest,
|
||||
EmotionHealthRequest,
|
||||
EmotionSuggestionsRequest
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -230,7 +229,7 @@ async def get_emotion_suggestions(
|
||||
# 调用服务层
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
config_id=config_id
|
||||
db=db
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
|
||||
@@ -163,7 +163,8 @@ async def write_server(
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.group_id,
|
||||
user_input.message,
|
||||
config_id,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
@@ -280,6 +281,7 @@ async def read_server(
|
||||
user_input.history,
|
||||
user_input.search_switch,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
@@ -548,6 +550,7 @@ async def get_write_task_result(
|
||||
@router.post("/status_type", response_model=ApiResponse)
|
||||
async def status_type(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
@@ -561,7 +564,11 @@ async def status_type(
|
||||
"""
|
||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
||||
try:
|
||||
result = await memory_agent_service.classify_message_type(user_input.message)
|
||||
result = await memory_agent_service.classify_message_type(
|
||||
user_input.message,
|
||||
user_input.config_id,
|
||||
db
|
||||
)
|
||||
return success(data=result)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Message type classification failed: {str(e)}")
|
||||
@@ -636,6 +643,7 @@ async def get_hot_memory_tags_by_user_api(
|
||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||
async def get_user_profile_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
@@ -659,7 +667,8 @@ async def get_user_profile_api(
|
||||
try:
|
||||
result = await memory_agent_service.get_user_profile(
|
||||
end_user_id=end_user_id,
|
||||
current_user_id=str(current_user.id)
|
||||
current_user_id=str(current_user.id),
|
||||
db=db
|
||||
)
|
||||
return success(data=result, msg="获取用户详情成功")
|
||||
except Exception as e:
|
||||
@@ -694,4 +703,41 @@ async def get_user_profile_api(
|
||||
# )
|
||||
# except Exception as e:
|
||||
# api_logger.error(f"API docs retrieval failed: {str(e)}")
|
||||
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
|
||||
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
|
||||
|
||||
|
||||
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
||||
async def get_end_user_connected_config(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
|
||||
通过以下流程获取配置:
|
||||
1. 根据 end_user_id 获取用户的 app_id
|
||||
2. 获取该应用的最新发布版本
|
||||
3. 从发布版本的 config 字段中提取 memory_config_id
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
包含 memory_config_id 和相关信息的响应
|
||||
"""
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config as get_config,
|
||||
)
|
||||
|
||||
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
try:
|
||||
result = get_config(end_user_id, db)
|
||||
return success(data=result, msg="获取终端用户关联配置成功")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"End user config not found: {str(e)}")
|
||||
return fail(BizCode.NOT_FOUND, str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||
@@ -1,22 +1,27 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
ReflectionConfig,
|
||||
ReflectionEngine,
|
||||
)
|
||||
from app.core.response_utils import success
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine
|
||||
from app.dependencies import get_current_user
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.services.memory_reflection_service import (
|
||||
MemoryReflectionService,
|
||||
WorkspaceAppService,
|
||||
)
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -9,18 +9,19 @@ LangChain Agent 封装
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain.agents import create_agent
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -198,10 +199,24 @@ class LangChainAgent:
|
||||
"""
|
||||
message_chat= message
|
||||
start_time = time.time()
|
||||
if config_id == None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:
|
||||
actual_config_id = config_id
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.db import get_db
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
@@ -295,10 +310,24 @@ class LangChainAgent:
|
||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||
logger.info("=" * 80)
|
||||
message_chat = message
|
||||
if config_id == None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:
|
||||
actual_config_id = config_id
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.db import get_db
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
history_term_memory = await self.term_memory_redis_read(end_user_id)
|
||||
if memory_flag:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
@@ -81,6 +82,7 @@ class Settings:
|
||||
VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query")
|
||||
|
||||
# Langfuse configuration
|
||||
LANGFUSE_ENABLED: bool = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
|
||||
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
|
||||
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
|
||||
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
|
||||
@@ -153,9 +155,6 @@ class Settings:
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
MEMORY_CONFIG_FILE: str = os.getenv("MEMORY_CONFIG_FILE", "config.json")
|
||||
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
|
||||
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
@@ -178,65 +177,6 @@ class Settings:
|
||||
return str(base_path / filename)
|
||||
return str(base_path)
|
||||
|
||||
def get_memory_config_path(self, config_file: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module configuration files.
|
||||
|
||||
Args:
|
||||
config_file: Optional config filename (defaults to MEMORY_CONFIG_FILE)
|
||||
|
||||
Returns:
|
||||
Full path to the config file
|
||||
"""
|
||||
if not config_file:
|
||||
config_file = self.MEMORY_CONFIG_FILE
|
||||
return str(Path(self.MEMORY_CONFIG_DIR) / config_file)
|
||||
|
||||
def load_memory_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Load memory module configuration from config.json.
|
||||
|
||||
Returns:
|
||||
Dictionary containing memory configuration
|
||||
"""
|
||||
config_path = self.get_memory_config_path(self.MEMORY_CONFIG_FILE)
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: Memory config file not found or malformed at {config_path}. Error: {e}")
|
||||
return {}
|
||||
|
||||
def load_memory_runtime_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Load memory module runtime configuration from runtime.json.
|
||||
|
||||
Returns:
|
||||
Dictionary containing runtime configuration
|
||||
"""
|
||||
runtime_path = self.get_memory_config_path(self.MEMORY_RUNTIME_FILE)
|
||||
try:
|
||||
with open(runtime_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: Memory runtime config not found or malformed at {runtime_path}. Error: {e}")
|
||||
return {"selections": {}}
|
||||
|
||||
def load_memory_dbrun_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Load memory module database run configuration from dbrun.json.
|
||||
|
||||
Returns:
|
||||
Dictionary containing dbrun configuration
|
||||
"""
|
||||
dbrun_path = self.get_memory_config_path(self.MEMORY_DBRUN_FILE)
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: Memory dbrun config not found or malformed at {dbrun_path}. Error: {e}")
|
||||
return {"selections": {}}
|
||||
|
||||
def ensure_memory_output_dir(self) -> None:
|
||||
"""
|
||||
Ensure the memory output directory exists.
|
||||
|
||||
@@ -141,7 +141,7 @@ class SearchService:
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
try:
|
||||
# Execute search using embedding_model_id from memory_config
|
||||
# Execute search using memory_config
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
@@ -149,7 +149,7 @@ class SearchService:
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
embedding_id=str(config.embedding_model_id),
|
||||
memory_config=config,
|
||||
rerank_alpha=rerank_alpha,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ from app.core.memory.agent.mcp_server.models.retrieval_models import (
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
@@ -41,8 +42,10 @@ async def Data_type_differentiation(
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
# Get LLM client from memory_config using factory pattern
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Render template
|
||||
try:
|
||||
|
||||
@@ -16,7 +16,8 @@ from app.core.memory.agent.mcp_server.models.problem_models import (
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
@@ -56,7 +57,9 @@ async def Split_The_Problem(
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Extract user ID from session
|
||||
user_id = session_service.resolve_user_id(sessionid)
|
||||
@@ -190,7 +193,9 @@ async def Problem_Extension(
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID from usermessages
|
||||
from app.core.memory.agent.utils.messages_tool import Resolve_username
|
||||
|
||||
@@ -21,8 +21,9 @@ from app.core.memory.agent.utils.messages_tool import (
|
||||
Resolve_username,
|
||||
Summary_messages_deal,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server.fastmcp import Context
|
||||
@@ -66,7 +67,9 @@ async def Summary(
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
@@ -210,7 +213,9 @@ async def Retrieve_Summary(
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
@@ -425,7 +430,9 @@ async def Input_Summary(
|
||||
search_service = get_context_resource(ctx, "search_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
"""
|
||||
Type classification utility for distinguishing read/write operations.
|
||||
"""
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.config import settings
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -44,7 +43,9 @@ async def status_typle(messages: str, llm_model_id: str) -> dict:
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
llm_client = get_llm_client(llm_model_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(llm_model_id)
|
||||
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
from typing import TypedDict, Annotated, List, Any
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph, add_messages
|
||||
import asyncio
|
||||
import json
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
import os
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from langchain_core.messages import HumanMessage
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from typing import Annotated, Any, List, TypedDict
|
||||
|
||||
# Removed global variable imports - use dependency injection instead
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from langchain_core.messages import AnyMessage, HumanMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph, add_messages
|
||||
|
||||
load_dotenv(find_dotenv())
|
||||
|
||||
@@ -53,7 +54,9 @@ class VerifyTool:
|
||||
async def model_1(self, state: State) -> State:
|
||||
if not self.llm_model_id:
|
||||
raise ValueError("llm_model_id is required but not provided")
|
||||
llm_client = get_llm_client(self.llm_model_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(self.llm_model_id)
|
||||
response_content = await llm_client.chat(
|
||||
messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])]
|
||||
)
|
||||
|
||||
@@ -13,13 +13,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
Memory_summary_generation,
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.core.memory.utils.embedder.embedder_utils import (
|
||||
get_embedder_client_from_config,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
@@ -67,9 +65,11 @@ async def write(
|
||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||
logger.info(f"Group ID: {group_id}")
|
||||
|
||||
# Construct clients from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
embedder_client = get_embedder_client_from_config(memory_config)
|
||||
# Construct clients from memory_config using factory pattern with db session
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
embedder_client = factory.get_embedder_client_from_config(memory_config)
|
||||
logger.info("LLM and embedding clients constructed")
|
||||
|
||||
# Initialize timing log
|
||||
@@ -100,7 +100,7 @@ async def write(
|
||||
# Step 2: Initialize and run ExtractionOrchestrator
|
||||
step_start = time.time()
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
pipeline_config = get_pipeline_config()
|
||||
pipeline_config = get_pipeline_config(memory_config)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
@@ -155,8 +155,8 @@ async def write(
|
||||
# Step 4: Generate Memory summaries and save to Neo4j
|
||||
step_start = time.time()
|
||||
try:
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -35,7 +35,9 @@ except NameError:
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
#TODO: Fix this
|
||||
# Default values (previously from definitions.py)
|
||||
@@ -47,11 +49,37 @@ class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], llm_client) -> List[str]:
|
||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
"""
|
||||
try:
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
# 3. 构建Prompt
|
||||
tag_list_str = ", ".join(tags)
|
||||
@@ -156,8 +184,7 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
llm_client = get_llm_client(DEFAULT_LLM_ID)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, llm_client)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
|
||||
@@ -18,8 +18,10 @@ if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
#TODO: Fix this
|
||||
@@ -59,7 +61,33 @@ class MemoryInsight:
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
self.llm_client = get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
"""关闭数据库连接。"""
|
||||
|
||||
@@ -25,8 +25,10 @@ except Exception:
|
||||
pass
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
@@ -47,7 +49,33 @@ class UserSummary:
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.connector = Neo4jConnector()
|
||||
self.llm = get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
await self.connector.close()
|
||||
|
||||
@@ -1,22 +1,34 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_CHUNKER_STRATEGY, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_CHUNKER_STRATEGY,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
|
||||
# Import from database module
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# Cypher queries for evaluation
|
||||
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
|
||||
@@ -52,7 +64,9 @@ async def ingest_contexts_via_full_pipeline(
|
||||
llm_available = True
|
||||
try:
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
|
||||
llm_available = False
|
||||
@@ -133,12 +147,13 @@ async def ingest_contexts_via_full_pipeline(
|
||||
return False
|
||||
|
||||
# 初始化 embedder 客户端
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
embedder_config_dict = get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
except Exception as e:
|
||||
@@ -236,15 +251,15 @@ async def ingest_contexts_via_full_pipeline(
|
||||
print("[Ingestion] Generating memory summaries...")
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
Memory_summary_generation,
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
|
||||
summaries = await Memory_summary_generation(
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs=dialog_data_list,
|
||||
llm_client=llm_client,
|
||||
embedding_id=embedding_name or SELECTED_EMBEDDING_ID
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
|
||||
except Exception as e:
|
||||
|
||||
@@ -15,7 +15,7 @@ import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -23,37 +23,38 @@ except ImportError:
|
||||
def load_dotenv():
|
||||
pass
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
SELECTED_EMBEDDING_ID
|
||||
)
|
||||
from app.core.memory.utils.llm_utils import get_llm_client
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
f1_score,
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
avg_context_tokens
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_metrics import (
|
||||
get_category_name,
|
||||
locomo_f1_score,
|
||||
locomo_multi_f1,
|
||||
get_category_name
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_utils import (
|
||||
load_locomo_data,
|
||||
extract_conversations,
|
||||
ingest_conversations_if_needed,
|
||||
load_locomo_data,
|
||||
resolve_temporal_references,
|
||||
select_and_format_information,
|
||||
retrieve_relevant_information,
|
||||
ingest_conversations_if_needed
|
||||
select_and_format_information,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
async def run_locomo_benchmark(
|
||||
@@ -160,10 +161,16 @@ async def run_locomo_benchmark(
|
||||
# Step 3: Initialize clients
|
||||
print("🔧 Initializing clients...")
|
||||
connector = Neo4jConnector()
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Initialize LLM client with database context
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Initialize embedder
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
# file name: check_neo4j_connection_fixed.py
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 1
|
||||
# 添加项目根目录到路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -34,7 +36,7 @@ def _loc_normalize(text: str) -> str:
|
||||
|
||||
# 尝试从 metrics.py 导入基础指标
|
||||
try:
|
||||
from common.metrics import f1_score, bleu1, jaccard
|
||||
from common.metrics import bleu1, f1_score, jaccard
|
||||
print("✅ 从 metrics.py 导入基础指标成功")
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 metrics.py 导入失败: {e}")
|
||||
@@ -111,10 +113,14 @@ try:
|
||||
|
||||
# 尝试从不同位置导入
|
||||
try:
|
||||
from locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
|
||||
from locomo.qwen_search_eval import (
|
||||
_resolve_relative_times,
|
||||
loc_f1_score,
|
||||
loc_multi_f1,
|
||||
)
|
||||
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
except ImportError:
|
||||
from qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
|
||||
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
|
||||
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
|
||||
except ImportError as e:
|
||||
@@ -429,13 +435,17 @@ async def run_enhanced_evaluation():
|
||||
return None
|
||||
|
||||
# 修正导入路径:使用 app.core.memory.src 前缀
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# 加载数据
|
||||
# 获取项目根目录
|
||||
@@ -458,10 +468,14 @@ async def run_enhanced_evaluation():
|
||||
# 初始化增强监控器
|
||||
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
|
||||
|
||||
llm = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化embedder
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -2,10 +2,11 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
import statistics
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
@@ -13,16 +14,31 @@ except Exception:
|
||||
return None
|
||||
|
||||
import re
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
|
||||
@@ -327,9 +343,13 @@ async def run_locomo_eval(
|
||||
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
|
||||
|
||||
# 使用异步LLM客户端
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
# 初始化embedder用于直接调用
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -2,11 +2,11 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -16,6 +16,7 @@ except Exception:
|
||||
|
||||
# 确保可以找到 src 及项目根路径
|
||||
import sys
|
||||
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR)))
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
@@ -25,19 +26,33 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
|
||||
# 与现有评估脚本保持一致的导入方式
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
try:
|
||||
# 优先从 extraction_utils1 导入
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline # type: ignore
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline, # type: ignore
|
||||
)
|
||||
except Exception:
|
||||
ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import exact_match
|
||||
except Exception:
|
||||
@@ -686,9 +701,13 @@ async def run_longmemeval_test(
|
||||
)
|
||||
|
||||
# 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
connector = Neo4jConnector()
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
@@ -748,10 +767,10 @@ async def run_longmemeval_test(
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
# for sm in summaries:
|
||||
# summary_text = str(sm.get("summary", "")).strip()
|
||||
# if summary_text:
|
||||
# contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要(最多3个)
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
|
||||
@@ -2,11 +2,11 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,15 +15,26 @@ except Exception:
|
||||
return None
|
||||
|
||||
# 与现有评估脚本保持一致的导入方式
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.llm_utils import get_llm_client
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import exact_match
|
||||
except Exception:
|
||||
@@ -647,9 +658,13 @@ async def run_longmemeval_test(
|
||||
items = qa_list[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化组件 - 使用异步LLM客户端
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
connector = Neo4jConnector()
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -4,19 +4,35 @@ import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
@@ -119,7 +135,7 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
|
||||
return merged
|
||||
|
||||
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid") -> Dict[str, Any]:
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
# Load data
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
@@ -134,7 +150,9 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
await ingest_contexts_via_full_pipeline(contexts, group_id)
|
||||
|
||||
# LLM client (使用异步调用)
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Evaluate each item
|
||||
connector = Neo4jConnector()
|
||||
@@ -159,6 +177,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
limit=search_limit,
|
||||
include=["dialogues", "statements", "entities"],
|
||||
output_path=None,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
except Exception:
|
||||
results = None
|
||||
@@ -242,7 +261,11 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
|
||||
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
|
||||
correct_flags.append(exact_match(pred, reference))
|
||||
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
)
|
||||
f1s.append(f1_score(str(pred), str(reference)))
|
||||
b1s.append(bleu1(str(pred), str(reference)))
|
||||
jss.append(jaccard(str(pred), str(reference)))
|
||||
|
||||
@@ -2,10 +2,10 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,6 +15,7 @@ except Exception:
|
||||
|
||||
# 路径与模块导入保持与现有评估脚本一致
|
||||
import sys
|
||||
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
@@ -23,17 +24,27 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
sys.path.insert(0, _p)
|
||||
|
||||
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
|
||||
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
|
||||
except Exception:
|
||||
# 兜底:简单实现(必要时)
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
@@ -226,13 +237,17 @@ async def run_memsciqa_test(
|
||||
items = all_items[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化 LLM(纯测试:不进行摄入)
|
||||
llm = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test)
|
||||
connector = Neo4jConnector()
|
||||
embedder = None
|
||||
if search_type in ("embedding", "hybrid"):
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -5,18 +5,17 @@ OpenAI LLM 客户端实现
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.llm import RedBearLLM
|
||||
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
|
||||
from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,7 +42,7 @@ class OpenAIClient(LLMClient):
|
||||
|
||||
# 初始化 Langfuse 回调处理器(如果启用)
|
||||
self.langfuse_handler = None
|
||||
if LANGFUSE_ENABLED:
|
||||
if settings.LANGFUSE_ENABLED:
|
||||
try:
|
||||
from langfuse.langchain import CallbackHandler
|
||||
self.langfuse_handler = CallbackHandler()
|
||||
|
||||
@@ -1,430 +0,0 @@
|
||||
"""
|
||||
MemSci 记忆系统主入口 - 重构版本
|
||||
|
||||
该模块是重构后的记忆系统主入口,使用新的模块化架构。
|
||||
旧版本入口(app/core/memory/src/main.py)已删除。
|
||||
|
||||
主要功能:
|
||||
1. 协调整个知识提取流水线
|
||||
2. 支持试运行模式和正常运行模式
|
||||
3. 使用重构后的 storage_services 模块
|
||||
4. 提供统一的配置管理和日志记录
|
||||
|
||||
作者:Lance77
|
||||
日期:2025-11-22
|
||||
"""
|
||||
|
||||
# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误
|
||||
import os
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "false"
|
||||
os.environ["LANGCHAIN_TRACING"] = "false"
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, Callable, Awaitable
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 导入重构后的模块
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
|
||||
# 导入数据加载函数
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
get_chunked_dialogs_with_preprocessing,
|
||||
get_chunked_dialogs_from_preprocessed,
|
||||
)
|
||||
# 导入配置模块(而不是直接导入变量)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.logging_config import get_memory_logger, log_time
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def main(
|
||||
# Required configuration parameters (no longer from global variables)
|
||||
chunker_strategy: str,
|
||||
group_id: str,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
llm_model_id: str,
|
||||
embedding_model_id: str,
|
||||
# Optional parameters
|
||||
dialogue_text: Optional[str] = None,
|
||||
is_pilot_run: bool = False,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
|
||||
):
|
||||
"""
|
||||
记忆系统主流程 - 重构版本 (Updated to eliminate global variables)
|
||||
|
||||
该函数是重构后的主入口,使用新的模块化架构。
|
||||
Global variables have been eliminated in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
chunker_strategy: Chunking strategy to use (required)
|
||||
group_id: Group ID for the operation (required)
|
||||
user_id: User ID for the operation (required)
|
||||
apply_id: Application ID for the operation (required)
|
||||
llm_model_id: LLM model ID to use (required)
|
||||
embedding_model_id: Embedding model ID to use (required)
|
||||
dialogue_text: 输入的对话文本(可选,用于试运行模式)
|
||||
is_pilot_run: 是否为试运行模式
|
||||
- True: 试运行模式,不保存到 Neo4j
|
||||
- False: 正常运行模式,保存到 Neo4j
|
||||
progress_callback: 可选的进度回调函数
|
||||
- 类型: Callable[[str, str, Optional[dict]], Awaitable[None]]
|
||||
- 参数1 (stage): 当前处理阶段标识符
|
||||
- 参数2 (message): 人类可读的进度消息
|
||||
- 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果
|
||||
- 在管线关键点调用以报告进度和结果数据
|
||||
|
||||
工作流程:
|
||||
1. 初始化客户端和配置
|
||||
2. 加载或准备数据
|
||||
3. 执行知识提取流水线
|
||||
4. 保存结果(正常模式)或输出结果(试运行模式)
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("MemSci 知识提取流水线 - 重构版本")
|
||||
print("=" * 60)
|
||||
print(f"运行模式: {'试运行(不保存到Neo4j)' if is_pilot_run else '正常运行(保存到Neo4j)'}")
|
||||
print("Using chunker strategy:", chunker_strategy)
|
||||
print("Using group ID:", group_id)
|
||||
print("Using model ID:", llm_model_id)
|
||||
print("Using embedding model ID:", embedding_model_id)
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化日志
|
||||
log_file = "logs/time.log"
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
|
||||
try:
|
||||
# 步骤 1: 初始化客户端
|
||||
logger.info("Initializing clients...")
|
||||
step_start = time.time()
|
||||
|
||||
llm_client = get_llm_client(llm_model_id)
|
||||
|
||||
# 获取 embedder 配置并转换为 RedBearModelConfig 对象
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
embedder_config_dict = get_embedder_config(embedding_model_id)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
log_time("Client Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 2: 加载或准备数据
|
||||
logger.info("Loading data...")
|
||||
logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}")
|
||||
logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}")
|
||||
logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}")
|
||||
step_start = time.time()
|
||||
|
||||
if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip():
|
||||
# 试运行模式:处理前端传入的对话文本
|
||||
logger.info("[MAIN] ✓ Using frontend dialogue text (pilot run mode)")
|
||||
import re
|
||||
|
||||
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
|
||||
pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)"
|
||||
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
|
||||
messages = [
|
||||
ConversationMessage(role=r, msg=c.strip())
|
||||
for r, c in matches if c.strip()
|
||||
]
|
||||
|
||||
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
|
||||
if not messages:
|
||||
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
|
||||
|
||||
# 创建对话上下文和对话数据
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context,
|
||||
ref_id="pilot_dialog_1",
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"}
|
||||
)
|
||||
|
||||
# 进度回调:开始预处理文本
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
# 对前端传入的对话进行分块处理
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
chunker_strategy=chunker_strategy,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dialog in chunked_dialogs:
|
||||
for i, chunk in enumerate(dialog.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dialog.id,
|
||||
"chunker_strategy": chunker_strategy
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
# 进度回调:预处理文本完成
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": chunker_strategy
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
else:
|
||||
# 正常运行模式:从 testdata.json 文件加载
|
||||
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
|
||||
logger.info("Loading data from testdata.json...")
|
||||
test_data_path = os.path.join(
|
||||
os.path.dirname(__file__), "data", "testdata.json"
|
||||
)
|
||||
|
||||
if not os.path.exists(test_data_path):
|
||||
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
|
||||
|
||||
# 进度回调:开始预处理文本
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
|
||||
chunker_strategy=chunker_strategy,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
indices=None,
|
||||
input_data_path=test_data_path,
|
||||
llm_client=llm_client,
|
||||
skip_cleaning=True,
|
||||
)
|
||||
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dialog in chunked_dialogs:
|
||||
for i, chunk in enumerate(dialog.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dialog.id,
|
||||
"chunker_strategy": chunker_strategy
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
# 进度回调:预处理文本完成
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": chunker_strategy
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 3: 初始化流水线编排器
|
||||
logger.info("Initializing extraction orchestrator...")
|
||||
step_start = time.time()
|
||||
|
||||
# 从 runtime.json 加载配置(已经过数据库覆写)
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}")
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback, # 传递进度回调
|
||||
embedding_id=embedding_model_id, # 传递嵌入模型ID
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 4: 执行知识提取流水线
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
|
||||
# 进度回调:正在知识抽取
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=is_pilot_run, # 传递试运行模式标志
|
||||
)
|
||||
|
||||
# 解包 extraction_result tuple
|
||||
# extraction_result 是一个包含 7 个元素的 tuple:
|
||||
# (dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes,
|
||||
# statement_chunk_edges, statement_entity_edges, entity_edges)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
# 进度回调:生成结果
|
||||
if progress_callback:
|
||||
await progress_callback("generating_results", "正在生成结果...")
|
||||
|
||||
|
||||
# 步骤 5: 保存结果或输出结果
|
||||
if is_pilot_run:
|
||||
logger.info("Pilot run mode: Skipping Neo4j save")
|
||||
print("\n试运行模式:跳过 Neo4j 保存,流水线处理完成。")
|
||||
print("提取结果已生成,可在相关输出中查看。")
|
||||
else:
|
||||
logger.info("Normal mode: Saving to Neo4j...")
|
||||
step_start = time.time()
|
||||
|
||||
# 创建索引和约束
|
||||
try:
|
||||
from app.repositories.neo4j.create_indexes import (
|
||||
create_fulltext_indexes,
|
||||
create_unique_constraints,
|
||||
)
|
||||
await create_fulltext_indexes()
|
||||
await create_unique_constraints()
|
||||
logger.info("Successfully created indexes and constraints")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes/constraints: {e}")
|
||||
|
||||
# 保存数据到 Neo4j
|
||||
try:
|
||||
from app.repositories.neo4j.graph_saver import (
|
||||
save_dialog_and_statements_to_neo4j,
|
||||
)
|
||||
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=dialogue_nodes,
|
||||
chunk_nodes=chunk_nodes,
|
||||
statement_nodes=statement_nodes,
|
||||
entity_nodes=entity_nodes,
|
||||
statement_chunk_edges=statement_chunk_edges,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
entity_edges=entity_edges,
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
print("\n✓ 成功保存所有数据到 Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
print("\n⚠ 部分数据保存到 Neo4j 失败")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to Neo4j: {e}", exc_info=True)
|
||||
print(f"\n✗ 保存到 Neo4j 失败: {e}")
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 6: 生成记忆摘要(可选)
|
||||
try:
|
||||
logger.info("Generating memory summaries...")
|
||||
step_start = time.time()
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
Memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import (
|
||||
add_memory_summary_statement_edges,
|
||||
)
|
||||
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
|
||||
)
|
||||
|
||||
if not is_pilot_run:
|
||||
# 保存记忆摘要到 Neo4j
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
await ms_connector.close()
|
||||
|
||||
log_time("Memory Summary Generation", time.time() - step_start, log_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline execution failed: {e}", exc_info=True)
|
||||
print(f"\n✗ 流水线执行失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
# 清理资源
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 记录总时间
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||
|
||||
# 添加完成标记
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Timing details saved to: {log_file}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ 流水线执行完成")
|
||||
print(f"✓ 总耗时: {total_time:.2f} 秒")
|
||||
print(f"✓ 详细日志: {log_file}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("⚠️ Warning: This script now requires explicit configuration parameters.")
|
||||
print("Global variables have been removed. Please provide configuration parameters.")
|
||||
print("Example usage:")
|
||||
print(" asyncio.run(main(")
|
||||
print(" chunker_strategy='RecursiveChunker',")
|
||||
print(" group_id='your_group_id',")
|
||||
print(" user_id='your_user_id',")
|
||||
print(" apply_id='your_apply_id',")
|
||||
print(" llm_model_id='your_llm_id',")
|
||||
print(" embedding_model_id='your_embedding_id'")
|
||||
print(" ))")
|
||||
|
||||
# This will fail because global variables are removed
|
||||
raise RuntimeError("Global variables removed. Please provide explicit configuration parameters.")
|
||||
@@ -20,14 +20,14 @@ Classes:
|
||||
MemorySummaryNode: Node representing a memory summary
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
import re
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.core.memory.utils.alias_utils import validate_aliases
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
def parse_historical_datetime(v):
|
||||
@@ -361,7 +361,7 @@ class ExtractedEntityNode(Node):
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
fact_summary: str = Field(..., description="Summary of the fact about this entity")
|
||||
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
|
||||
@@ -10,9 +10,10 @@ Classes:
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
"""Represents an extracted entity from dialogue.
|
||||
|
||||
@@ -5,7 +5,10 @@ import math
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
@@ -14,16 +17,14 @@ from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import (
|
||||
ForgettingEngine,
|
||||
)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_embedder_config,
|
||||
get_pipeline_config,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
|
||||
from app.core.memory.utils.data.text_utils import extract_plain_query
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_graph,
|
||||
search_graph_by_chunk_id,
|
||||
@@ -34,6 +35,7 @@ from app.repositories.neo4j.graph_search import (
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
@@ -324,229 +326,229 @@ def apply_reranker_placeholder(
|
||||
If config enables reranker, annotate items with a final_score equal to combined_score
|
||||
and keep ordering. This is a no-op reranker to be replaced later.
|
||||
"""
|
||||
try:
|
||||
rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load reranker config: {e}")
|
||||
rc = {}
|
||||
if not rc or not rc.get("enabled", False):
|
||||
return results
|
||||
# try:
|
||||
# rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
|
||||
# except Exception as e:
|
||||
# logger.debug(f"Failed to load reranker config: {e}")
|
||||
# rc = {}
|
||||
# if not rc or not rc.get("enabled", False):
|
||||
# return results
|
||||
|
||||
top_k = int(rc.get("top_k", 100))
|
||||
model_name = rc.get("model", "placeholder")
|
||||
# top_k = int(rc.get("top_k", 100))
|
||||
# model_name = rc.get("model", "placeholder")
|
||||
|
||||
for cat, items in results.items():
|
||||
head = items[:top_k]
|
||||
for it in head:
|
||||
base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
|
||||
it["final_score"] = base
|
||||
it["reranker_model"] = model_name
|
||||
# Keep overall order by final_score if present, otherwise combined/score
|
||||
results[cat] = sorted(
|
||||
items,
|
||||
key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
|
||||
reverse=True,
|
||||
)
|
||||
# for cat, items in results.items():
|
||||
# head = items[:top_k]
|
||||
# for it in head:
|
||||
# base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
|
||||
# it["final_score"] = base
|
||||
# it["reranker_model"] = model_name
|
||||
# # Keep overall order by final_score if present, otherwise combined/score
|
||||
# results[cat] = sorted(
|
||||
# items,
|
||||
# key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
|
||||
# reverse=True,
|
||||
# )
|
||||
return results
|
||||
|
||||
|
||||
async def apply_llm_reranker(
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
reranker_client: Optional[Any] = None,
|
||||
llm_weight: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Apply LLM-based reranking to search results.
|
||||
# async def apply_llm_reranker(
|
||||
# results: Dict[str, List[Dict[str, Any]]],
|
||||
# query_text: str,
|
||||
# reranker_client: Optional[Any] = None,
|
||||
# llm_weight: Optional[float] = None,
|
||||
# top_k: Optional[int] = None,
|
||||
# batch_size: Optional[int] = None,
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Apply LLM-based reranking to search results.
|
||||
|
||||
Args:
|
||||
results: Search results organized by category
|
||||
query_text: Original search query
|
||||
reranker_client: Optional pre-initialized reranker client
|
||||
llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
top_k: Maximum number of items to rerank per category
|
||||
batch_size: Number of items to process concurrently
|
||||
# Args:
|
||||
# results: Search results organized by category
|
||||
# query_text: Original search query
|
||||
# reranker_client: Optional pre-initialized reranker client
|
||||
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
# top_k: Maximum number of items to rerank per category
|
||||
# batch_size: Number of items to process concurrently
|
||||
|
||||
Returns:
|
||||
Reranked results with final_score and reranker_model fields
|
||||
"""
|
||||
# Load reranker configuration from runtime.json
|
||||
try:
|
||||
rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load reranker config: {e}")
|
||||
rc = {}
|
||||
# Returns:
|
||||
# Reranked results with final_score and reranker_model fields
|
||||
# """
|
||||
# # Load reranker configuration from runtime.json
|
||||
# # try:
|
||||
# # rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
|
||||
# # except Exception as e:
|
||||
# # logger.debug(f"Failed to load reranker config: {e}")
|
||||
# # rc = {}
|
||||
|
||||
# Check if reranking is enabled
|
||||
enabled = rc.get("enabled", False)
|
||||
if not enabled:
|
||||
logger.debug("LLM reranking is disabled in configuration")
|
||||
return results
|
||||
# # Check if reranking is enabled
|
||||
# enabled = rc.get("enabled", False)
|
||||
# if not enabled:
|
||||
# logger.debug("LLM reranking is disabled in configuration")
|
||||
# return results
|
||||
|
||||
# Load configuration parameters with defaults
|
||||
llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
# # Load configuration parameters with defaults
|
||||
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
# top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
|
||||
# Initialize reranker client if not provided
|
||||
if reranker_client is None:
|
||||
try:
|
||||
reranker_client = get_reranker_client()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
return results
|
||||
# # Initialize reranker client if not provided
|
||||
# if reranker_client is None:
|
||||
# try:
|
||||
# reranker_client = get_reranker_client()
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
# return results
|
||||
|
||||
# Get model name for metadata
|
||||
model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
# # Get model name for metadata
|
||||
# model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
|
||||
# Process each category
|
||||
reranked_results = {}
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
items = results.get(category, [])
|
||||
if not items:
|
||||
reranked_results[category] = []
|
||||
continue
|
||||
# # Process each category
|
||||
# reranked_results = {}
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
# items = results.get(category, [])
|
||||
# if not items:
|
||||
# reranked_results[category] = []
|
||||
# continue
|
||||
|
||||
# Select top K items by combined_score for reranking
|
||||
sorted_items = sorted(
|
||||
items,
|
||||
key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
reverse=True
|
||||
)
|
||||
# # Select top K items by combined_score for reranking
|
||||
# sorted_items = sorted(
|
||||
# items,
|
||||
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
# reverse=True
|
||||
# )
|
||||
|
||||
top_items = sorted_items[:top_k]
|
||||
remaining_items = sorted_items[top_k:]
|
||||
# top_items = sorted_items[:top_k]
|
||||
# remaining_items = sorted_items[top_k:]
|
||||
|
||||
# Extract text content from each item
|
||||
def extract_text(item: Dict[str, Any]) -> str:
|
||||
"""Extract text content from a result item."""
|
||||
# Try different text fields based on category
|
||||
text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
return str(text).strip()
|
||||
# # Extract text content from each item
|
||||
# def extract_text(item: Dict[str, Any]) -> str:
|
||||
# """Extract text content from a result item."""
|
||||
# # Try different text fields based on category
|
||||
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
# return str(text).strip()
|
||||
|
||||
# Batch items for concurrent processing
|
||||
batches = []
|
||||
for i in range(0, len(top_items), batch_size):
|
||||
batch = top_items[i:i + batch_size]
|
||||
batches.append(batch)
|
||||
# # Batch items for concurrent processing
|
||||
# batches = []
|
||||
# for i in range(0, len(top_items), batch_size):
|
||||
# batch = top_items[i:i + batch_size]
|
||||
# batches.append(batch)
|
||||
|
||||
# Process batches concurrently
|
||||
async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Process a batch of items with LLM relevance scoring."""
|
||||
scored_batch = []
|
||||
# # Process batches concurrently
|
||||
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
# """Process a batch of items with LLM relevance scoring."""
|
||||
# scored_batch = []
|
||||
|
||||
for item in batch:
|
||||
item_text = extract_text(item)
|
||||
# for item in batch:
|
||||
# item_text = extract_text(item)
|
||||
|
||||
# Skip items with no text
|
||||
if not item_text:
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["llm_relevance_score"] = 0.0
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
continue
|
||||
# # Skip items with no text
|
||||
# if not item_text:
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["llm_relevance_score"] = 0.0
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
# continue
|
||||
|
||||
# Create relevance scoring prompt
|
||||
prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
# # Create relevance scoring prompt
|
||||
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
|
||||
Query: {query_text}
|
||||
# Query: {query_text}
|
||||
|
||||
Result: {item_text}
|
||||
# Result: {item_text}
|
||||
|
||||
Respond with only a number between 0.0 and 1.0, where:
|
||||
- 0.0 means completely irrelevant
|
||||
- 1.0 means perfectly relevant
|
||||
# Respond with only a number between 0.0 and 1.0, where:
|
||||
# - 0.0 means completely irrelevant
|
||||
# - 1.0 means perfectly relevant
|
||||
|
||||
Relevance score:"""
|
||||
# Relevance score:"""
|
||||
|
||||
# Send request to LLM
|
||||
try:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
response = await reranker_client.chat(messages)
|
||||
# # Send request to LLM
|
||||
# try:
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
# response = await reranker_client.chat(messages)
|
||||
|
||||
# Parse LLM response to extract relevance score
|
||||
response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
# # Parse LLM response to extract relevance score
|
||||
# response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
|
||||
# Try to extract a float from the response
|
||||
try:
|
||||
# Remove any non-numeric characters except decimal point
|
||||
import re
|
||||
score_match = re.search(r'(\d+\.?\d*)', response_text)
|
||||
if score_match:
|
||||
llm_score = float(score_match.group(1))
|
||||
# Clamp to [0.0, 1.0]
|
||||
llm_score = max(0.0, min(1.0, llm_score))
|
||||
else:
|
||||
raise ValueError("No numeric score found in response")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
llm_score = None
|
||||
# # Try to extract a float from the response
|
||||
# try:
|
||||
# # Remove any non-numeric characters except decimal point
|
||||
# import re
|
||||
# score_match = re.search(r'(\d+\.?\d*)', response_text)
|
||||
# if score_match:
|
||||
# llm_score = float(score_match.group(1))
|
||||
# # Clamp to [0.0, 1.0]
|
||||
# llm_score = max(0.0, min(1.0, llm_score))
|
||||
# else:
|
||||
# raise ValueError("No numeric score found in response")
|
||||
# except (ValueError, AttributeError) as e:
|
||||
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
# llm_score = None
|
||||
|
||||
# Calculate final score
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# # Calculate final score
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
|
||||
if llm_score is not None:
|
||||
final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
item_copy["llm_relevance_score"] = llm_score
|
||||
else:
|
||||
# Use combined_score as fallback
|
||||
final_score = combined_score
|
||||
item_copy["llm_relevance_score"] = combined_score
|
||||
# if llm_score is not None:
|
||||
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
# item_copy["llm_relevance_score"] = llm_score
|
||||
# else:
|
||||
# # Use combined_score as fallback
|
||||
# final_score = combined_score
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
|
||||
item_copy["final_score"] = final_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["llm_relevance_score"] = combined_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
# item_copy["final_score"] = final_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
|
||||
return scored_batch
|
||||
# return scored_batch
|
||||
|
||||
# Process all batches concurrently
|
||||
try:
|
||||
batch_tasks = [process_batch(batch) for batch in batches]
|
||||
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
# # Process all batches concurrently
|
||||
# try:
|
||||
# batch_tasks = [process_batch(batch) for batch in batches]
|
||||
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
# Merge batch results
|
||||
scored_items = []
|
||||
for result in batch_results:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"Batch processing failed: {result}")
|
||||
continue
|
||||
scored_items.extend(result)
|
||||
# # Merge batch results
|
||||
# scored_items = []
|
||||
# for result in batch_results:
|
||||
# if isinstance(result, Exception):
|
||||
# logger.warning(f"Batch processing failed: {result}")
|
||||
# continue
|
||||
# scored_items.extend(result)
|
||||
|
||||
# Add remaining items (not in top K) with their combined_score as final_score
|
||||
for item in remaining_items:
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_items.append(item_copy)
|
||||
# # Add remaining items (not in top K) with their combined_score as final_score
|
||||
# for item in remaining_items:
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_items.append(item_copy)
|
||||
|
||||
# Sort all items by final_score in descending order
|
||||
scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
reranked_results[category] = scored_items
|
||||
# # Sort all items by final_score in descending order
|
||||
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
# reranked_results[category] = scored_items
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# Return original items with combined_score as final_score
|
||||
for item in items:
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item["final_score"] = combined_score
|
||||
item["reranker_model"] = model_name
|
||||
reranked_results[category] = items
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# # Return original items with combined_score as final_score
|
||||
# for item in items:
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item["final_score"] = combined_score
|
||||
# item["reranker_model"] = model_name
|
||||
# reranked_results[category] = items
|
||||
|
||||
return reranked_results
|
||||
# return reranked_results
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
@@ -556,7 +558,7 @@ async def run_hybrid_search(
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
embedding_id: str,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
@@ -564,11 +566,14 @@ async def run_hybrid_search(
|
||||
"""
|
||||
|
||||
Run search with specified type: 'keyword', 'embedding', or 'hybrid'
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing embedding_model_id and config_id
|
||||
"""
|
||||
# Start overall timing
|
||||
search_start_time = time.time()
|
||||
latency_metrics = {}
|
||||
logger.info(f"using embedding_id:{embedding_id}...")
|
||||
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
|
||||
|
||||
# Clean and normalize the incoming query before use/logging
|
||||
query_text = extract_plain_query(query_text)
|
||||
@@ -621,7 +626,9 @@ async def run_hybrid_search(
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
embedder_config_dict = get_embedder_config(embedding_id)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
@@ -683,7 +690,7 @@ async def run_hybrid_search(
|
||||
if use_forgetting_rerank:
|
||||
# Load forgetting parameters from pipeline config
|
||||
try:
|
||||
pc = get_pipeline_config()
|
||||
pc = get_pipeline_config(memory_config)
|
||||
forgetting_cfg = pc.forgetting_engine
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
|
||||
@@ -711,16 +718,16 @@ async def run_hybrid_search(
|
||||
|
||||
# Apply LLM reranking if enabled
|
||||
llm_rerank_applied = False
|
||||
if use_llm_rerank:
|
||||
try:
|
||||
reranked_results = await apply_llm_reranker(
|
||||
results=reranked_results,
|
||||
query_text=query_text,
|
||||
)
|
||||
llm_rerank_applied = True
|
||||
logger.info("LLM reranking applied successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
# if use_llm_rerank:
|
||||
# try:
|
||||
# reranked_results = await apply_llm_reranker(
|
||||
# results=reranked_results,
|
||||
# query_text=query_text,
|
||||
# )
|
||||
# llm_rerank_applied = True
|
||||
# logger.info("LLM reranking applied successfully")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
|
||||
results["reranked_results"] = reranked_results
|
||||
results["combined_summary"] = {
|
||||
@@ -896,90 +903,95 @@ async def search_chunk_by_chunk_id(
|
||||
return {"chunks": chunks}
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the hybrid graph search CLI.
|
||||
# def main():
|
||||
# """Main entry point for the hybrid graph search CLI.
|
||||
|
||||
Parses command line arguments and executes search with specified parameters.
|
||||
Supports keyword, embedding, and hybrid search modes.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
|
||||
parser.add_argument(
|
||||
"--query", "-q", required=True, help="Free-text query to search"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search-type",
|
||||
"-t",
|
||||
choices=["keyword", "embedding", "hybrid"],
|
||||
default="hybrid",
|
||||
help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-name",
|
||||
"-m",
|
||||
default="openai/nomic-embed-text:v1.5",
|
||||
help="Embedding config name from config.json (default: openai/nomic-embed-text:v1.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-id",
|
||||
"-g",
|
||||
default=None,
|
||||
help="Optional group_id to filter results (default: None)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
"-k",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Max number of results per type (default: 5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include",
|
||||
"-i",
|
||||
nargs="+",
|
||||
default=["statements", "chunks", "entities", "summaries"],
|
||||
choices=["statements", "chunks", "entities", "summaries"],
|
||||
help="Which targets to search for embedding search (default: statements chunks entities summaries)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
default="search_results.json",
|
||||
help="Path to save the search results JSON (default: search_results.json)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rerank-alpha",
|
||||
"-a",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--forgetting-rerank",
|
||||
action="store_true",
|
||||
help="Apply forgetting curve during reranking for hybrid search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-rerank",
|
||||
action="store_true",
|
||||
help="Apply LLM-based reranking for hybrid search.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
# Parses command line arguments and executes search with specified parameters.
|
||||
# Supports keyword, embedding, and hybrid search modes.
|
||||
# """
|
||||
# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
|
||||
# parser.add_argument(
|
||||
# "--query", "-q", required=True, help="Free-text query to search"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--search-type",
|
||||
# "-t",
|
||||
# choices=["keyword", "embedding", "hybrid"],
|
||||
# default="hybrid",
|
||||
# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--config-id",
|
||||
# "-c",
|
||||
# type=int,
|
||||
# required=True,
|
||||
# help="Database configuration ID (required)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--group-id",
|
||||
# "-g",
|
||||
# default=None,
|
||||
# help="Optional group_id to filter results (default: None)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--limit",
|
||||
# "-k",
|
||||
# type=int,
|
||||
# default=5,
|
||||
# help="Max number of results per type (default: 5)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--include",
|
||||
# "-i",
|
||||
# nargs="+",
|
||||
# default=["statements", "chunks", "entities", "summaries"],
|
||||
# choices=["statements", "chunks", "entities", "summaries"],
|
||||
# help="Which targets to search for embedding search (default: statements chunks entities summaries)"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--output",
|
||||
# "-o",
|
||||
# default="search_results.json",
|
||||
# help="Path to save the search results JSON (default: search_results.json)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--rerank-alpha",
|
||||
# "-a",
|
||||
# type=float,
|
||||
# default=0.6,
|
||||
# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--forgetting-rerank",
|
||||
# action="store_true",
|
||||
# help="Apply forgetting curve during reranking for hybrid search.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--llm-rerank",
|
||||
# action="store_true",
|
||||
# help="Apply LLM-based reranking for hybrid search.",
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
|
||||
asyncio.run(
|
||||
run_hybrid_search(
|
||||
query_text=args.query,
|
||||
search_type=args.search_type,
|
||||
group_id=args.group_id,
|
||||
limit=args.limit,
|
||||
include=args.include,
|
||||
output_path=args.output,
|
||||
embedding_id=config_defs.SELECTED_EMBEDDING_ID,
|
||||
rerank_alpha=args.rerank_alpha,
|
||||
use_forgetting_rerank=args.forgetting_rerank,
|
||||
use_llm_rerank=args.llm_rerank,
|
||||
)
|
||||
)
|
||||
# # Load memory config from database
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
# memory_config = MemoryConfigService.load_memory_config(args.config_id)
|
||||
|
||||
# asyncio.run(
|
||||
# run_hybrid_search(
|
||||
# query_text=args.query,
|
||||
# search_type=args.search_type,
|
||||
# group_id=args.group_id,
|
||||
# limit=args.limit,
|
||||
# include=args.include,
|
||||
# output_path=args.output,
|
||||
# memory_config=memory_config,
|
||||
# rerank_alpha=args.rerank_alpha,
|
||||
# use_forgetting_rerank=args.forgetting_rerank,
|
||||
# use_llm_rerank=args.llm_rerank,
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
"""
|
||||
去重功能函数
|
||||
"""
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from typing import List, Dict, Tuple, Any
|
||||
from app.core.memory.models.graph_models import(
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode
|
||||
)
|
||||
import os
|
||||
from datetime import datetime
|
||||
import difflib # 提供字符串相似度计算工具
|
||||
import asyncio
|
||||
import difflib # 提供字符串相似度计算工具
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
)
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
|
||||
|
||||
# 模块级类型统一工具函数
|
||||
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
|
||||
"""统一实体类型:基于LLM建议或启发式规则选择最合适的类型。
|
||||
@@ -705,7 +708,8 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
id_redirect: Dict[str, str],
|
||||
config: DedupConfig | None = None,
|
||||
config: DedupConfig,
|
||||
llm_client = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
|
||||
"""
|
||||
基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。
|
||||
@@ -717,26 +721,13 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
"""
|
||||
llm_records: List[str] = []
|
||||
try:
|
||||
# 优先使用运行时配置;若未提供配置,使用模型默认值,不再回退到环境变量
|
||||
enable_switch = (
|
||||
bool(config.enable_llm_dedup_blockwise) if config is not None else DedupConfig().enable_llm_dedup_blockwise
|
||||
)
|
||||
if not enable_switch:
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
# 从配置读取 LLM 迭代参数;若无配置则使用 DedupConfig 的默认值
|
||||
_defaults = DedupConfig()
|
||||
block_size = (config.llm_block_size if config is not None else _defaults.llm_block_size)
|
||||
block_concurrency = (config.llm_block_concurrency if config is not None else _defaults.llm_block_concurrency)
|
||||
pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency)
|
||||
max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds)
|
||||
|
||||
# 动态导入 llm 客户端(修正导入路径)
|
||||
try:
|
||||
llm_utils_mod = importlib.import_module("app.core.memory.utils.llm.llm_utils")
|
||||
get_llm_client_fn = llm_utils_mod.get_llm_client
|
||||
except Exception as e:
|
||||
llm_records.append(f"[LLM错误] 无法导入 llm_utils 模块: {e}")
|
||||
if not bool(config.enable_llm_dedup_blockwise):
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
# 从配置读取 LLM 迭代参数
|
||||
block_size = config.llm_block_size
|
||||
block_concurrency = config.llm_block_concurrency
|
||||
pair_concurrency = config.llm_pair_concurrency
|
||||
max_rounds = config.llm_max_rounds
|
||||
|
||||
try:
|
||||
llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm")
|
||||
@@ -745,14 +736,9 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
llm_records.append(f"[LLM错误] 无法导入 entity_dedup_llm 模块: {e}")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
# 获取 LLM 客户端
|
||||
try:
|
||||
llm_client = get_llm_client_fn()
|
||||
if llm_client is None:
|
||||
llm_records.append("[LLM错误] LLM 客户端初始化失败:返回 None")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
except Exception as e:
|
||||
llm_records.append(f"[LLM错误] 获取 LLM 客户端失败: {e}")
|
||||
# 验证 LLM 客户端
|
||||
if llm_client is None:
|
||||
llm_records.append("[LLM错误] LLM 客户端未提供")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
llm_redirect, llm_records = await llm_fn(
|
||||
@@ -813,7 +799,8 @@ async def LLM_disamb_decision(
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
id_redirect: Dict[str, str],
|
||||
config: DedupConfig | None = None,
|
||||
config: DedupConfig,
|
||||
llm_client = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]:
|
||||
"""
|
||||
预消歧阶段:对“同名但类型不同”的实体对调用LLM进行消歧,
|
||||
@@ -824,22 +811,16 @@ async def LLM_disamb_decision(
|
||||
disamb_records: List[str] = []
|
||||
blocked_pairs: set[tuple[str, str]] = set()
|
||||
try:
|
||||
enable_switch = (
|
||||
config.enable_llm_disambiguation
|
||||
if config is not None
|
||||
else DedupConfig().enable_llm_disambiguation
|
||||
)
|
||||
if not bool(enable_switch):
|
||||
if not bool(config.enable_llm_disambiguation):
|
||||
return deduped_entities, id_redirect, blocked_pairs, disamb_records
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import (
|
||||
llm_disambiguate_pairs_iterative,
|
||||
)
|
||||
|
||||
# 获取 LLM 客户端并验证
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
# 验证 LLM 客户端
|
||||
if llm_client is None:
|
||||
disamb_records.append("[DISAMB错误] LLM 客户端初始化失败:返回 None")
|
||||
disamb_records.append("[DISAMB错误] LLM 客户端未提供")
|
||||
return deduped_entities, id_redirect, blocked_pairs, disamb_records
|
||||
|
||||
merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative(
|
||||
@@ -895,6 +876,7 @@ async def deduplicate_entities_and_edges(
|
||||
report_append: bool = False,
|
||||
report_stage_notes: List[str] | None = None,
|
||||
dedup_config: DedupConfig | None = None,
|
||||
llm_client = None,
|
||||
) -> Tuple[
|
||||
List[ExtractedEntityNode],
|
||||
List[StatementEntityEdge],
|
||||
@@ -911,7 +893,7 @@ async def deduplicate_entities_and_edges(
|
||||
|
||||
# 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并
|
||||
deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision(
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
|
||||
)
|
||||
|
||||
# 2) 模糊匹配(本地规则)
|
||||
@@ -936,7 +918,7 @@ async def deduplicate_entities_and_edges(
|
||||
|
||||
if should_trigger_llm:
|
||||
deduped_entities, id_redirect, llm_decision_records = await LLM_decision(
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
|
||||
)
|
||||
else:
|
||||
llm_decision_records = []
|
||||
|
||||
@@ -10,15 +10,27 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
)
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( # 导入报告写入以在跳过时追加说明
|
||||
_write_dedup_fusion_report,
|
||||
deduplicate_entities_and_edges,
|
||||
)
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
get_dedup_candidates_for_entities, # 导入ge函数,用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互
|
||||
from app.repositories.neo4j.graph_search import get_dedup_candidates_for_entities # 导入ge函数,用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
|
||||
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges, _write_dedup_fusion_report # 导入报告写入以在跳过时追加说明
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from app.repositories.neo4j.neo4j_connector import (
|
||||
Neo4jConnector, # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互
|
||||
)
|
||||
|
||||
|
||||
def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt,用于将任意类型的输入值解析为datetime对象(处理实体节点中的时间字段)
|
||||
@@ -72,6 +84,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
||||
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
||||
dedup_config: DedupConfig | None = None,
|
||||
llm_client = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||
"""
|
||||
第二层去重消歧:
|
||||
@@ -137,13 +150,14 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes)
|
||||
|
||||
# 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重)
|
||||
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges = await deduplicate_entities_and_edges(
|
||||
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges, _ = await deduplicate_entities_and_edges(
|
||||
union_entities,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
report_stage="第二层去重消歧",
|
||||
report_append=True,
|
||||
dedup_config=dedup_config,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import second_layer_dedup_and_merge_with_neo4j
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.models.graph_models import (
|
||||
DialogueNode,
|
||||
ChunkNode,
|
||||
StatementNode,
|
||||
DialogueNode,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
deduplicate_entities_and_edges,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
|
||||
second_layer_dedup_and_merge_with_neo4j,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def dedup_layers_and_merge_and_return(
|
||||
@@ -29,8 +33,9 @@ async def dedup_layers_and_merge_and_return(
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: Optional[ExtractionPipelineConfig] = None,
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client = None,
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
@@ -48,12 +53,9 @@ async def dedup_layers_and_merge_and_return(
|
||||
返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。
|
||||
"""
|
||||
|
||||
# 默认从 runtime.json 加载管线配置,避免回退到环境变量
|
||||
# pipeline_config is required - caller must provide it
|
||||
if pipeline_config is None:
|
||||
try:
|
||||
pipeline_config = get_pipeline_config()
|
||||
except Exception:
|
||||
pipeline_config = None
|
||||
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
|
||||
|
||||
# 先探测 group_id,决定报告写入策略
|
||||
group_id: Optional[str] = None
|
||||
@@ -70,6 +72,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
report_stage="第一层去重消歧",
|
||||
report_append=False,
|
||||
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
# 初始化第二层融合结果为第一层结果
|
||||
@@ -88,6 +91,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
statement_entity_edges=dedup_statement_entity_edges,
|
||||
entity_entity_edges=dedup_entity_entity_edges,
|
||||
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
|
||||
llm_client=llm_client,
|
||||
)
|
||||
else:
|
||||
print("Skip second-layer dedup: missing connector")
|
||||
|
||||
@@ -1253,6 +1253,7 @@ class ExtractionOrchestrator:
|
||||
report_stage="第一层去重消歧(试运行)",
|
||||
report_append=False,
|
||||
dedup_config=self.config.deduplication,
|
||||
llm_client=self.llm_client,
|
||||
)
|
||||
|
||||
# 保存去重消歧的详细记录到实例变量
|
||||
@@ -1284,6 +1285,7 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
self.config,
|
||||
self.connector,
|
||||
llm_client=self.llm_client,
|
||||
)
|
||||
|
||||
# 解包返回值
|
||||
|
||||
@@ -5,11 +5,13 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
class EmbeddingGenerator:
|
||||
@@ -21,7 +23,9 @@ class EmbeddingGenerator:
|
||||
Args:
|
||||
embedding_id: 嵌入模型 ID
|
||||
"""
|
||||
embedder_config = get_embedder_config(embedding_id)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config = config_service.get_embedder_config(embedding_id)
|
||||
self.embedder_client = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(embedder_config),
|
||||
)
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.base_response import RobustLLMResponse
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
from pydantic import Field
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
from app.core.memory.models.base_response import RobustLLMResponse
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class MemorySummaryResponse(RobustLLMResponse):
|
||||
@@ -91,22 +87,17 @@ async def _process_chunk_summary(
|
||||
return None
|
||||
|
||||
|
||||
async def Memory_summary_generation(
|
||||
async def memory_summary_generation(
|
||||
chunked_dialogs: List[DialogData],
|
||||
llm_client,
|
||||
embedding_id,
|
||||
embedder_client: OpenAIEmbedderClient,
|
||||
) -> List[MemorySummaryNode]:
|
||||
"""Generate memory summaries per chunk, embed them, and return nodes."""
|
||||
embedder_cfg_dict = get_embedder_config(embedding_id)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(embedder_cfg_dict),
|
||||
)
|
||||
|
||||
# Collect all tasks for parallel processing
|
||||
tasks = []
|
||||
for dialog in chunked_dialogs:
|
||||
for chunk in dialog.chunks:
|
||||
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder))
|
||||
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder_client))
|
||||
|
||||
# Process all chunks in parallel
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo
|
||||
|
||||
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
from app.core.memory.models.variate_config import StatementExtractionConfig
|
||||
from app.core.memory.utils.data.ontology import (
|
||||
LABEL_DEFINITIONS,
|
||||
RelevenceInfo,
|
||||
StatementType,
|
||||
TemporalInfo,
|
||||
)
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo, RelevenceInfo
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -8,21 +8,25 @@
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
neo4j_query_all,
|
||||
neo4j_query_part,
|
||||
neo4j_statement_all,
|
||||
neo4j_statement_part,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
_root_logger = logging.getLogger()
|
||||
@@ -135,14 +139,20 @@ class ReflectionEngine:
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
|
||||
if self.llm_client is None:
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
model_id = self.llm_client
|
||||
self.llm_client = get_llm_client(model_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(model_id)
|
||||
|
||||
if self.get_data_func is None:
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
@@ -154,11 +164,15 @@ class ReflectionEngine:
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
if self.render_evaluate_prompt_func is None:
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
from app.core.memory.utils.prompt.template_render import (
|
||||
render_evaluate_prompt,
|
||||
)
|
||||
self.render_evaluate_prompt_func = render_evaluate_prompt
|
||||
|
||||
if self.render_reflexion_prompt_func is None:
|
||||
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
||||
from app.core.memory.utils.prompt.template_render import (
|
||||
render_reflexion_prompt,
|
||||
)
|
||||
self.render_reflexion_prompt_func = render_reflexion_prompt
|
||||
|
||||
if self.conflict_schema is None:
|
||||
@@ -170,7 +184,9 @@ class ReflectionEngine:
|
||||
self.reflexion_schema = ReflexionResultSchema
|
||||
|
||||
if self.update_query is None:
|
||||
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
UPDATE_STATEMENT_INVALID_AT,
|
||||
)
|
||||
self.update_query = UPDATE_STATEMENT_INVALID_AT
|
||||
|
||||
self._lazy_init_done = True
|
||||
|
||||
@@ -4,10 +4,20 @@
|
||||
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
|
||||
"""
|
||||
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.search_strategy import (
|
||||
SearchResult,
|
||||
SearchStrategy,
|
||||
)
|
||||
from app.core.memory.storage_services.search.semantic_search import (
|
||||
SemanticSearchStrategy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SearchStrategy",
|
||||
@@ -34,7 +44,7 @@ async def run_hybrid_search(
|
||||
include: list[str] | None = None,
|
||||
alpha: float = 0.6,
|
||||
use_forgetting_curve: bool = False,
|
||||
embedding_id: str | None = None,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""运行混合搜索(向后兼容的函数式API)
|
||||
@@ -51,24 +61,26 @@ async def run_hybrid_search(
|
||||
include: 要包含的搜索类别列表
|
||||
alpha: BM25分数权重(0.0-1.0)
|
||||
use_forgetting_curve: 是否使用遗忘曲线
|
||||
embedding_id: 嵌入模型ID
|
||||
memory_config: MemoryConfig object containing embedding_model_id
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
dict: 搜索结果字典,格式与旧API兼容
|
||||
"""
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# 使用提供的embedding_id或默认值
|
||||
emb_id = embedding_id or config_defs.SELECTED_EMBEDDING_ID
|
||||
if not memory_config:
|
||||
raise ValueError("memory_config is required for search")
|
||||
|
||||
# 初始化客户端
|
||||
connector = Neo4jConnector()
|
||||
embedder_config_dict = get_embedder_config(emb_id)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
|
||||
@@ -5,15 +5,20 @@
|
||||
使用余弦相似度进行语义匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.storage_services.search.search_strategy import (
|
||||
SearchResult,
|
||||
SearchStrategy,
|
||||
)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
@@ -62,7 +67,9 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
"""
|
||||
try:
|
||||
# 从数据库读取嵌入器配置
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
|
||||
@@ -9,7 +9,6 @@ from .config_utils import (
|
||||
get_chunker_config,
|
||||
get_embedder_config,
|
||||
get_model_config,
|
||||
get_neo4j_config,
|
||||
get_picture_config,
|
||||
get_pipeline_config,
|
||||
get_pruning_config,
|
||||
@@ -41,7 +40,6 @@ __all__ = [
|
||||
# config_utils
|
||||
"get_model_config",
|
||||
"get_embedder_config",
|
||||
"get_neo4j_config",
|
||||
"get_chunker_config",
|
||||
"get_pipeline_config",
|
||||
"get_pruning_config",
|
||||
|
||||
@@ -1,90 +1,74 @@
|
||||
from app.core.memory.models.variate_config import (
|
||||
DedupConfig,
|
||||
ExtractionPipelineConfig,
|
||||
ForgettingEngineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import CONFIG
|
||||
from app.db import get_db
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
"""
|
||||
Configuration utilities - Backward compatibility layer
|
||||
|
||||
DEPRECATED: These functions now require a db session parameter.
|
||||
New code should use MemoryConfigService(db) instance directly.
|
||||
|
||||
For functions that don't require db (get_pipeline_config, get_pruning_config),
|
||||
they are still re-exported here.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# These functions don't require db - safe to re-export as static methods
|
||||
get_pipeline_config = MemoryConfigService.get_pipeline_config
|
||||
get_pruning_config = MemoryConfigService.get_pruning_config
|
||||
|
||||
|
||||
def get_model_config(model_id: str, db: Session | None = None) -> dict:
|
||||
def get_model_config(model_id: str, db=None):
|
||||
"""DEPRECATED: Use MemoryConfigService(db).get_model_config(model_id) directly."""
|
||||
if db is None:
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen) # 取到真正的 Session
|
||||
raise ValueError(
|
||||
"get_model_config now requires a db session. "
|
||||
"Use MemoryConfigService(db).get_model_config(model_id) directly."
|
||||
)
|
||||
return MemoryConfigService(db).get_model_config(model_id)
|
||||
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
if not config:
|
||||
print(f"模型ID {model_id} 不存在")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
|
||||
# 从环境变量读取超时和重试配置
|
||||
from app.core.config import settings
|
||||
|
||||
model_config = {
|
||||
"model_name": apiConfig.model_name,
|
||||
"provider": apiConfig.provider,
|
||||
"api_key": apiConfig.api_key,
|
||||
"base_url": apiConfig.api_base,
|
||||
"model_config_id":apiConfig.model_config_id,
|
||||
"type": config.type,
|
||||
# 添加超时和重试配置,避免 LLM 请求超时
|
||||
"timeout": settings.LLM_TIMEOUT, # 从环境变量读取,默认120秒
|
||||
"max_retries": settings.LLM_MAX_RETRIES, # 从环境变量读取,默认2次
|
||||
}
|
||||
# 写入model_config.log文件中
|
||||
with open("logs/model_config.log", "a", encoding="utf-8") as f:
|
||||
f.write(f"模型ID: {model_id}\n")
|
||||
f.write(f"模型配置信息:\n{model_config}\n")
|
||||
f.write("=============================\n\n")
|
||||
return model_config
|
||||
|
||||
def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict:
|
||||
def get_embedder_config(embedding_id: str, db=None):
|
||||
"""DEPRECATED: Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."""
|
||||
if db is None:
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen) # 取到真正的 Session
|
||||
raise ValueError(
|
||||
"get_embedder_config now requires a db session. "
|
||||
"Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."
|
||||
)
|
||||
return MemoryConfigService(db).get_embedder_config(embedding_id)
|
||||
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=embedding_id)
|
||||
if not config:
|
||||
print(f"嵌入模型ID {embedding_id} 不存在")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
model_config = {
|
||||
"model_name": apiConfig.model_name,
|
||||
"provider": apiConfig.provider,
|
||||
"api_key": apiConfig.api_key,
|
||||
"base_url": apiConfig.api_base,
|
||||
"model_config_id":apiConfig.model_config_id,
|
||||
# Ensure required field for RedBearModelConfig validation
|
||||
"type": config.type,
|
||||
# 添加超时和重试配置,避免嵌入服务请求超时
|
||||
"timeout": 120.0, # 嵌入服务超时时间(秒)
|
||||
"max_retries": 5, # 最大重试次数
|
||||
}
|
||||
# 写入embedder_config.log文件中
|
||||
with open("logs/embedder_config.log", "a", encoding="utf-8") as f:
|
||||
f.write(f"嵌入模型ID: {embedding_id}\n")
|
||||
f.write(f"嵌入模型配置信息:\n{model_config}\n")
|
||||
f.write("=============================\n\n")
|
||||
return model_config
|
||||
|
||||
def get_neo4j_config() -> dict:
|
||||
"""Retrieves the Neo4j configuration from the config file."""
|
||||
return CONFIG.get("neo4j", {})
|
||||
def get_picture_config(llm_name: str) -> dict:
|
||||
"""Retrieves the configuration for a specific model from the config file."""
|
||||
"""Retrieves the configuration for a specific model from the config file.
|
||||
|
||||
.. deprecated::
|
||||
This function is deprecated and will be removed in a future version.
|
||||
Use database-backed model configuration instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"get_picture_config is deprecated and will be removed in a future version. "
|
||||
"Use database-backed model configuration instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
for model_config in CONFIG.get("picture_recognition", []):
|
||||
if model_config["llm_name"] == llm_name:
|
||||
return model_config
|
||||
raise ValueError(f"Model '{llm_name}' not found in config.json")
|
||||
|
||||
|
||||
def get_voice_config(llm_name: str) -> dict:
|
||||
"""Retrieves the configuration for a specific model from the config file."""
|
||||
"""Retrieves the configuration for a specific model from the config file.
|
||||
|
||||
.. deprecated::
|
||||
This function is deprecated and will be removed in a future version.
|
||||
Use database-backed model configuration instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"get_voice_config is deprecated and will be removed in a future version. "
|
||||
"Use database-backed model configuration instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
for model_config in CONFIG.get("voice_recognition", []):
|
||||
if model_config["llm_name"] == llm_name:
|
||||
return model_config
|
||||
@@ -92,19 +76,8 @@ def get_voice_config(llm_name: str) -> dict:
|
||||
|
||||
|
||||
def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
"""Retrieves the configuration for a specific chunker strategy.
|
||||
"""Retrieves the configuration for a specific chunker strategy."""
|
||||
|
||||
Enhancements:
|
||||
- Supports default configs for `LLMChunker` and `HybridChunker` if not present.
|
||||
- Falls back to the first available chunker config when the requested one is missing.
|
||||
"""
|
||||
# 1) Try to find exact match in config
|
||||
chunker_list = CONFIG.get("chunker_list", [])
|
||||
for chunker_config in chunker_list:
|
||||
if chunker_config.get("chunker_strategy") == chunker_strategy:
|
||||
return chunker_config
|
||||
|
||||
# 2) Provide sane defaults for newer strategies
|
||||
default_configs = {
|
||||
"RecursiveChunker": {
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
@@ -112,7 +85,6 @@ def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
"chunk_size": 512,
|
||||
"min_characters_per_chunk": 50
|
||||
},
|
||||
|
||||
"LLMChunker": {
|
||||
"chunker_strategy": "LLMChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
@@ -137,127 +109,6 @@ def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
if chunker_strategy in default_configs:
|
||||
return default_configs[chunker_strategy]
|
||||
|
||||
# 3) Fallback: use first available config but tag with requested strategy
|
||||
if chunker_list:
|
||||
fallback = chunker_list[0].copy()
|
||||
fallback["chunker_strategy"] = chunker_strategy
|
||||
# Non-fatal notice for visibility in logs if any
|
||||
print(f"Warning: Using first available chunker config as fallback for '{chunker_strategy}'")
|
||||
return fallback
|
||||
|
||||
# 4) If no configs available at all
|
||||
raise ValueError(
|
||||
f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available"
|
||||
f"Chunker '{chunker_strategy}' not found "
|
||||
)
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
def get_pipeline_config(
|
||||
config_id: int,
|
||||
db: Session | None = None,
|
||||
) -> ExtractionPipelineConfig:
|
||||
"""Build ExtractionPipelineConfig from database.
|
||||
|
||||
Args:
|
||||
config_id: Database configuration ID (required). Loads pipeline
|
||||
settings from the DataConfig table.
|
||||
db: Optional database session. If not provided, a new session
|
||||
will be created.
|
||||
|
||||
Returns:
|
||||
ExtractionPipelineConfig with deduplication, statement extraction,
|
||||
and forgetting engine settings loaded from database.
|
||||
|
||||
Raises:
|
||||
ValueError: If config_id not found in database.
|
||||
"""
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
|
||||
# Load from database
|
||||
if db is None:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
db_config = DataConfigRepository.get_by_id(db, config_id)
|
||||
if db_config is None:
|
||||
raise ValueError(f"Configuration {config_id} not found in database")
|
||||
|
||||
# Build DedupConfig from database
|
||||
dedup_kwargs = {
|
||||
"enable_llm_dedup_blockwise": bool(db_config.enable_llm_dedup_blockwise) if db_config.enable_llm_dedup_blockwise is not None else False,
|
||||
"enable_llm_disambiguation": bool(db_config.enable_llm_disambiguation) if db_config.enable_llm_disambiguation is not None else False,
|
||||
}
|
||||
|
||||
# Fuzzy thresholds
|
||||
if db_config.t_name_strict is not None:
|
||||
dedup_kwargs["fuzzy_name_threshold_strict"] = db_config.t_name_strict
|
||||
if db_config.t_type_strict is not None:
|
||||
dedup_kwargs["fuzzy_type_threshold_strict"] = db_config.t_type_strict
|
||||
if db_config.t_overall is not None:
|
||||
dedup_kwargs["fuzzy_overall_threshold"] = db_config.t_overall
|
||||
|
||||
dedup_config = DedupConfig(**dedup_kwargs)
|
||||
|
||||
# Build StatementExtractionConfig from database
|
||||
stmt_kwargs = {}
|
||||
if db_config.statement_granularity is not None:
|
||||
stmt_kwargs["statement_granularity"] = db_config.statement_granularity
|
||||
if db_config.include_dialogue_context is not None:
|
||||
stmt_kwargs["include_dialogue_context"] = bool(db_config.include_dialogue_context)
|
||||
if db_config.max_context is not None:
|
||||
stmt_kwargs["max_dialogue_context_chars"] = db_config.max_context
|
||||
|
||||
stmt_config = StatementExtractionConfig(**stmt_kwargs)
|
||||
|
||||
# Build ForgettingEngineConfig from database
|
||||
forget_kwargs = {}
|
||||
if db_config.offset is not None:
|
||||
forget_kwargs["offset"] = db_config.offset
|
||||
if db_config.lambda_time is not None:
|
||||
forget_kwargs["lambda_time"] = db_config.lambda_time
|
||||
if db_config.lambda_mem is not None:
|
||||
forget_kwargs["lambda_mem"] = db_config.lambda_mem
|
||||
|
||||
forget_config = ForgettingEngineConfig(**forget_kwargs)
|
||||
|
||||
return ExtractionPipelineConfig(
|
||||
statement_extraction=stmt_config,
|
||||
deduplication=dedup_config,
|
||||
forgetting_engine=forget_config,
|
||||
)
|
||||
|
||||
|
||||
def get_pruning_config(
|
||||
config_id: int,
|
||||
db: Session | None = None,
|
||||
) -> dict:
|
||||
"""Retrieve semantic pruning config from database.
|
||||
|
||||
Args:
|
||||
config_id: Database configuration ID (required).
|
||||
db: Optional database session.
|
||||
|
||||
Returns:
|
||||
Dict suitable for PruningConfig.model_validate with keys:
|
||||
- pruning_switch: bool
|
||||
- pruning_scene: str ("education" | "online_service" | "outbound")
|
||||
- pruning_threshold: float (0-0.9)
|
||||
|
||||
Raises:
|
||||
ValueError: If config_id not found in database.
|
||||
"""
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
|
||||
if db is None:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
db_config = DataConfigRepository.get_by_id(db, config_id)
|
||||
if db_config is None:
|
||||
raise ValueError(f"Configuration {config_id} not found in database")
|
||||
|
||||
return {
|
||||
"pruning_switch": bool(db_config.pruning_enabled) if db_config.pruning_enabled is not None else False,
|
||||
"pruning_scene": db_config.pruning_scene or "education",
|
||||
"pruning_threshold": float(db_config.pruning_threshold) if db_config.pruning_threshold is not None else 0.5,
|
||||
}
|
||||
|
||||
@@ -1,269 +1,268 @@
|
||||
"""
|
||||
配置加载模块 - DEPRECATED
|
||||
# """
|
||||
# 配置加载模块 - DEPRECATED
|
||||
|
||||
⚠️ DEPRECATION NOTICE ⚠️
|
||||
This module is deprecated and will be removed in a future version.
|
||||
Global configuration variables have been eliminated in favor of dependency injection.
|
||||
# ⚠️ DEPRECATION NOTICE ⚠️
|
||||
# This module is deprecated and will be removed in a future version.
|
||||
# Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
Use the new MemoryConfig system instead:
|
||||
- app.core.memory_config.config.MemoryConfig for configuration objects
|
||||
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
# Use the new MemoryConfig system instead:
|
||||
# - app.schemas.memory_config_schema.MemoryConfig for configuration objects
|
||||
# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id)
|
||||
|
||||
阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED
|
||||
阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED
|
||||
阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
# 阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED
|
||||
# 阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED
|
||||
# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
# """
|
||||
# import json
|
||||
# import os
|
||||
# import threading
|
||||
# from datetime import datetime, timedelta
|
||||
# from typing import Any, Dict, Optional
|
||||
|
||||
#TODO: Fix this
|
||||
# #TODO: Fix this
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except Exception:
|
||||
pass
|
||||
# try:
|
||||
# from dotenv import load_dotenv
|
||||
# load_dotenv()
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
# Import unified configuration system
|
||||
try:
|
||||
from app.core.config import settings
|
||||
USE_UNIFIED_CONFIG = True
|
||||
except ImportError:
|
||||
USE_UNIFIED_CONFIG = False
|
||||
settings = None
|
||||
# # Import unified configuration system
|
||||
# try:
|
||||
# from app.core.config import settings
|
||||
# USE_UNIFIED_CONFIG = True
|
||||
# except ImportError:
|
||||
# USE_UNIFIED_CONFIG = False
|
||||
# settings = None
|
||||
|
||||
# PROJECT_ROOT 应该指向 app/core/memory/ 目录
|
||||
# __file__ = app/core/memory/utils/config/definitions.py
|
||||
# os.path.dirname(__file__) = app/core/memory/utils/config
|
||||
# os.path.dirname(...) = app/core/memory/utils
|
||||
# os.path.dirname(...) = app/core/memory
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# # PROJECT_ROOT 应该指向 app/core/memory/ 目录
|
||||
# # __file__ = app/core/memory/utils/config/definitions.py
|
||||
# # os.path.dirname(__file__) = app/core/memory/utils/config
|
||||
# # os.path.dirname(...) = app/core/memory/utils
|
||||
# # os.path.dirname(...) = app/core/memory
|
||||
# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# DEPRECATED: Global configuration lock removed
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
# # DEPRECATED: Global configuration lock removed
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
|
||||
# DEPRECATED: Legacy config.json loading removed
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
CONFIG = {}
|
||||
# # DEPRECATED: Legacy config.json loading removed
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
# CONFIG = {}
|
||||
|
||||
DEFAULT_VALUES = {
|
||||
"llm_name": "openai/qwen-plus",
|
||||
"embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
"group_id": "group_123",
|
||||
"user_id": "default_user",
|
||||
"apply_id": "default_apply",
|
||||
"llm_agent_name": "openai/qwen-plus",
|
||||
"llm_verify_name": "openai/qwen-plus",
|
||||
"llm_image_recognition": "openai/qwen-plus",
|
||||
"llm_voice_recognition": "openai/qwen-plus",
|
||||
"prompt_level": "DEBUG",
|
||||
"reflexion_iteration_period": "3",
|
||||
"reflexion_range": "retrieval",
|
||||
"reflexion_baseline": "TIME",
|
||||
}
|
||||
# DEFAULT_VALUES = {
|
||||
# "llm_name": "openai/qwen-plus",
|
||||
# "embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
# "chunker_strategy": "RecursiveChunker",
|
||||
# "group_id": "group_123",
|
||||
# "user_id": "default_user",
|
||||
# "apply_id": "default_apply",
|
||||
# "llm_agent_name": "openai/qwen-plus",
|
||||
# "llm_verify_name": "openai/qwen-plus",
|
||||
# "llm_image_recognition": "openai/qwen-plus",
|
||||
# "llm_voice_recognition": "openai/qwen-plus",
|
||||
# "prompt_level": "DEBUG",
|
||||
# "reflexion_iteration_period": "3",
|
||||
# "reflexion_range": "retrieval",
|
||||
# "reflexion_baseline": "TIME",
|
||||
# }
|
||||
|
||||
# DEPRECATED: Legacy global variables for backward compatibility only
|
||||
# These will be removed in a future version
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
|
||||
SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
|
||||
# # DEPRECATED: Legacy global variables for backward compatibility only
|
||||
# # These will be removed in a future version
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
|
||||
# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
|
||||
|
||||
|
||||
# 阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
def _load_from_runtime_json() -> Dict[str, Any]:
|
||||
"""
|
||||
DEPRECATED: Legacy runtime.json loading
|
||||
# # 阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
# def _load_from_runtime_json() -> Dict[str, Any]:
|
||||
# """
|
||||
# DEPRECATED: Legacy runtime.json loading
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Empty configuration (legacy support only)
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return {"selections": {}}
|
||||
# Returns:
|
||||
# Dict[str, Any]: Empty configuration (legacy support only)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return {"selections": {}}
|
||||
|
||||
|
||||
# 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器
|
||||
# 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
|
||||
# 保留此函数仅为向后兼容
|
||||
def _load_from_database() -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
DEPRECATED: Legacy database configuration loading
|
||||
# # 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器
|
||||
# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
|
||||
# # 保留此函数仅为向后兼容
|
||||
# def _load_from_database() -> Optional[Dict[str, Any]]:
|
||||
# """
|
||||
# DEPRECATED: Legacy database configuration loading
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: None (deprecated functionality)
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return None
|
||||
# Returns:
|
||||
# Optional[Dict[str, Any]]: None (deprecated functionality)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return None
|
||||
|
||||
|
||||
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
|
||||
"""
|
||||
DEPRECATED: 将运行时配置暴露为全局常量供项目使用
|
||||
# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
|
||||
# """
|
||||
# DEPRECATED: 将运行时配置暴露为全局常量供项目使用
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Global configuration variables have been eliminated in favor of dependency injection.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
Use the new MemoryConfig system instead:
|
||||
- app.core.memory_config.config.MemoryConfig for configuration objects
|
||||
- Pass configuration objects as parameters instead of using global variables
|
||||
# Use the new MemoryConfig system instead:
|
||||
# - app.core.memory_config.config.MemoryConfig for configuration objects
|
||||
# - Pass configuration objects as parameters instead of using global variables
|
||||
|
||||
Args:
|
||||
runtime_cfg: 运行时配置字典
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# Args:
|
||||
# runtime_cfg: 运行时配置字典
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# Keep minimal global state for backward compatibility only
|
||||
# These will be removed in a future version
|
||||
global RUNTIME_CONFIG, SELECTIONS
|
||||
# # Keep minimal global state for backward compatibility only
|
||||
# # These will be removed in a future version
|
||||
# global RUNTIME_CONFIG, SELECTIONS
|
||||
|
||||
RUNTIME_CONFIG = runtime_cfg
|
||||
SELECTIONS = RUNTIME_CONFIG.get("selections", {})
|
||||
# RUNTIME_CONFIG = runtime_cfg
|
||||
# SELECTIONS = RUNTIME_CONFIG.get("selections", {})
|
||||
|
||||
# All other global variables have been removed
|
||||
# Use MemoryConfig objects instead
|
||||
# # All other global variables have been removed
|
||||
# # Use MemoryConfig objects instead
|
||||
|
||||
|
||||
# 初始化:使用统一配置加载器
|
||||
def _initialize_configuration() -> None:
|
||||
"""
|
||||
DEPRECATED: Legacy configuration initialization
|
||||
# # 初始化:使用统一配置加载器
|
||||
# def _initialize_configuration() -> None:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration initialization
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# Initialize with empty configuration for backward compatibility
|
||||
_expose_runtime_constants({"selections": {}})
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# # Initialize with empty configuration for backward compatibility
|
||||
# _expose_runtime_constants({"selections": {}})
|
||||
|
||||
|
||||
# 模块加载时自动初始化配置
|
||||
_initialize_configuration()
|
||||
# # 模块加载时自动初始化配置
|
||||
# _initialize_configuration()
|
||||
|
||||
# DEPRECATED: Global variables removed
|
||||
# These variables have been eliminated in favor of dependency injection
|
||||
# Use MemoryConfig objects instead of accessing global variables
|
||||
# # DEPRECATED: Global variables removed
|
||||
# # These variables have been eliminated in favor of dependency injection
|
||||
# # Use MemoryConfig objects instead of accessing global variables
|
||||
|
||||
|
||||
# 公共 API:动态重新加载配置
|
||||
def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
|
||||
"""
|
||||
DEPRECATED: Legacy configuration reloading
|
||||
# # 公共 API:动态重新加载配置
|
||||
# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration reloading
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
For new code, use:
|
||||
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
# For new code, use:
|
||||
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID (deprecated)
|
||||
force_reload: Force reload flag (deprecated)
|
||||
# Args:
|
||||
# config_id: Configuration ID (deprecated)
|
||||
# force_reload: Force reload flag (deprecated)
|
||||
|
||||
Returns:
|
||||
bool: Always returns False (deprecated functionality)
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
# Returns:
|
||||
# bool: Always returns False (deprecated functionality)
|
||||
# """
|
||||
# import logging
|
||||
# import warnings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
warnings.warn(
|
||||
"reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# warnings.warn(
|
||||
# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
|
||||
"Use MemoryConfig objects with dependency injection instead.")
|
||||
# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
|
||||
# "Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
return False
|
||||
# return False
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_current_config_id() -> Optional[str]:
|
||||
"""
|
||||
DEPRECATED: Legacy config ID retrieval
|
||||
# def get_current_config_id() -> Optional[str]:
|
||||
# """
|
||||
# DEPRECATED: Legacy config ID retrieval
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None (deprecated functionality)
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return None
|
||||
# Returns:
|
||||
# Optional[str]: None (deprecated functionality)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return None
|
||||
|
||||
|
||||
def ensure_fresh_config(config_id = None) -> bool:
|
||||
"""
|
||||
DEPRECATED: Legacy configuration freshness check
|
||||
# def ensure_fresh_config(config_id = None) -> bool:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration freshness check
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
For new code, use:
|
||||
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
# For new code, use:
|
||||
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID (deprecated)
|
||||
# Args:
|
||||
# config_id: Configuration ID (deprecated)
|
||||
|
||||
Returns:
|
||||
bool: Always returns False (deprecated functionality)
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
# Returns:
|
||||
# bool: Always returns False (deprecated functionality)
|
||||
# """
|
||||
# import logging
|
||||
# import warnings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
warnings.warn(
|
||||
"ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# warnings.warn(
|
||||
# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
|
||||
"Use MemoryConfig objects with dependency injection instead.")
|
||||
# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
|
||||
# "Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
return False
|
||||
# return False
|
||||
|
||||
|
||||
|
||||
@@ -6,8 +6,9 @@ This module provides centralized functions for creating embedder clients.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -66,7 +67,8 @@ def get_embedder_client(embedding_id: str) -> OpenAIEmbedderClient:
|
||||
raise ValueError("Embedding ID is required but was not provided")
|
||||
|
||||
try:
|
||||
embedder_config_dict = get_embedder_config(embedding_id)
|
||||
with get_db_context() as db:
|
||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -13,105 +13,225 @@ async def handle_response(response: type[BaseModel]) -> dict:
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
def get_llm_client_from_config(memory_config: "MemoryConfig") -> OpenAIClient:
|
||||
class MemoryClientFactory:
|
||||
"""
|
||||
Get LLM client from MemoryConfig object.
|
||||
Factory for creating LLM, embedder, and reranker clients.
|
||||
|
||||
**PREFERRED METHOD**: Use this function in production code when you have a MemoryConfig object.
|
||||
This ensures proper configuration management and multi-tenant support.
|
||||
Initialize once with db session, then call methods without passing db each time.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing llm_model_id
|
||||
|
||||
Returns:
|
||||
OpenAIClient: Initialized LLM client
|
||||
|
||||
Raises:
|
||||
ValueError: If LLM model ID is not configured or client initialization fails
|
||||
|
||||
Example:
|
||||
>>> llm_client = get_llm_client_from_config(memory_config)
|
||||
>>> factory = MemoryClientFactory(db)
|
||||
>>> llm_client = factory.get_llm_client(model_id)
|
||||
>>> embedder_client = factory.get_embedder_client(embedding_id)
|
||||
"""
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"Configuration {memory_config.config_id} has no LLM model configured"
|
||||
)
|
||||
return get_llm_client(str(memory_config.llm_model_id))
|
||||
|
||||
def __init__(self, db: Session):
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
self._config_service = MemoryConfigService(db)
|
||||
|
||||
def get_llm_client(self, llm_id: str) -> OpenAIClient:
|
||||
"""Get LLM client by model ID."""
|
||||
if not llm_id:
|
||||
raise ValueError("LLM ID is required")
|
||||
|
||||
try:
|
||||
model_config = self._config_service.get_model_config(llm_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
return OpenAIClient(
|
||||
RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),
|
||||
type_=model_config.get("type")
|
||||
)
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
def get_embedder_client(self, embedding_id: str):
|
||||
"""Get embedder client by model ID."""
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
|
||||
if not embedding_id:
|
||||
raise ValueError("Embedding ID is required")
|
||||
|
||||
try:
|
||||
embedder_config = self._config_service.get_embedder_config(embedding_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
return OpenAIEmbedderClient(
|
||||
RedBearModelConfig(
|
||||
model_name=embedder_config.get("model_name"),
|
||||
provider=embedder_config.get("provider"),
|
||||
api_key=embedder_config.get("api_key"),
|
||||
base_url=embedder_config.get("base_url")
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
model_name = embedder_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize embedder client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
def get_reranker_client(self, rerank_id: str) -> OpenAIClient:
|
||||
"""Get reranker client by model ID."""
|
||||
if not rerank_id:
|
||||
raise ValueError("Rerank ID is required")
|
||||
|
||||
try:
|
||||
model_config = self._config_service.get_model_config(rerank_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
return OpenAIClient(
|
||||
RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),
|
||||
type_=model_config.get("type")
|
||||
)
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
def get_llm_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
|
||||
"""Get LLM client from MemoryConfig object.
|
||||
|
||||
Args:
|
||||
memory_config: Configuration containing llm_model_id
|
||||
|
||||
Returns:
|
||||
OpenAIClient configured for the LLM model
|
||||
|
||||
Raises:
|
||||
ValueError: If memory_config has no LLM model configured
|
||||
"""
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"Configuration {memory_config.config_id} has no LLM model configured"
|
||||
)
|
||||
return self.get_llm_client(str(memory_config.llm_model_id))
|
||||
|
||||
def get_embedder_client_from_config(self, memory_config: "MemoryConfig"):
|
||||
"""Get embedder client from MemoryConfig object.
|
||||
|
||||
Args:
|
||||
memory_config: Configuration containing embedding_model_id
|
||||
|
||||
Returns:
|
||||
OpenAIEmbedderClient configured for the embedding model
|
||||
|
||||
Raises:
|
||||
ValueError: If memory_config has no embedding model configured
|
||||
"""
|
||||
if not memory_config.embedding_model_id:
|
||||
raise ValueError(
|
||||
f"Configuration {memory_config.config_id} has no embedding model configured"
|
||||
)
|
||||
return self.get_embedder_client(str(memory_config.embedding_model_id))
|
||||
|
||||
def get_reranker_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
|
||||
"""Get reranker client from MemoryConfig object.
|
||||
|
||||
Args:
|
||||
memory_config: Configuration containing rerank_model_id
|
||||
|
||||
Returns:
|
||||
OpenAIClient configured for the reranker model
|
||||
|
||||
Raises:
|
||||
ValueError: If memory_config has no rerank model configured
|
||||
"""
|
||||
if not memory_config.rerank_model_id:
|
||||
raise ValueError(
|
||||
f"Configuration {memory_config.config_id} has no rerank model configured"
|
||||
)
|
||||
return self.get_reranker_client(str(memory_config.rerank_model_id))
|
||||
|
||||
|
||||
def get_llm_client(llm_id: str):
|
||||
"""
|
||||
Get LLM client by model ID.
|
||||
# Legacy functions for backward compatibility
|
||||
def get_llm_client_from_config(memory_config: "MemoryConfig", db: Session) -> OpenAIClient:
|
||||
"""Get LLM client from MemoryConfig object.
|
||||
|
||||
**LEGACY/TEST METHOD**: Use this function only for:
|
||||
- Test/evaluation code where you have a model ID directly
|
||||
- Legacy code that hasn't been migrated to MemoryConfig yet
|
||||
DEPRECATED: Use MemoryClientFactory(db).get_llm_client_from_config(memory_config) instead.
|
||||
|
||||
For production code with MemoryConfig, use get_llm_client_from_config() instead.
|
||||
This function is maintained for backward compatibility during migration to the
|
||||
factory pattern. New code should create a MemoryClientFactory instance and use
|
||||
its get_llm_client_from_config method directly.
|
||||
|
||||
Args:
|
||||
llm_id: LLM model ID (required)
|
||||
memory_config: Configuration containing llm_model_id
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OpenAIClient: Initialized LLM client
|
||||
OpenAIClient configured for the LLM model
|
||||
|
||||
Raises:
|
||||
ValueError: If llm_id is not provided or client initialization fails
|
||||
|
||||
Example:
|
||||
>>> # For tests/evaluations only
|
||||
>>> llm_client = get_llm_client("model-uuid-string")
|
||||
ValueError: If memory_config has no LLM model configured
|
||||
"""
|
||||
if not llm_id:
|
||||
raise ValueError("LLM ID is required but was not provided")
|
||||
|
||||
try:
|
||||
model_config = get_model_config(llm_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
llm_client = OpenAIClient(RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),type_=model_config.get("type"))
|
||||
return llm_client
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
|
||||
return MemoryClientFactory(db).get_llm_client_from_config(memory_config)
|
||||
|
||||
|
||||
def get_reranker_client(rerank_id: str):
|
||||
"""
|
||||
Get an LLM client configured for reranking.
|
||||
def get_llm_client(llm_id: str, db: Session) -> OpenAIClient:
|
||||
"""Get LLM client by model ID.
|
||||
|
||||
DEPRECATED: Use MemoryClientFactory(db).get_llm_client(llm_id) instead.
|
||||
|
||||
This function is maintained for backward compatibility during migration to the
|
||||
factory pattern. New code should create a MemoryClientFactory instance and use
|
||||
its get_llm_client method directly.
|
||||
|
||||
Args:
|
||||
rerank_id: Reranker model ID (required)
|
||||
llm_id: LLM model ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OpenAIClient: Initialized client for the reranker model
|
||||
|
||||
Raises:
|
||||
ValueError: If rerank_id is not provided or client initialization fails
|
||||
OpenAIClient configured for the LLM model
|
||||
"""
|
||||
if not rerank_id:
|
||||
raise ValueError("Rerank ID is required but was not provided")
|
||||
return MemoryClientFactory(db).get_llm_client(llm_id)
|
||||
|
||||
|
||||
def get_embedder_client(embedding_id: str, db: Session):
|
||||
"""Get embedder client by model ID.
|
||||
|
||||
try:
|
||||
model_config = get_model_config(rerank_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
|
||||
DEPRECATED: Use MemoryClientFactory(db).get_embedder_client(embedding_id) instead.
|
||||
|
||||
try:
|
||||
reranker_client = OpenAIClient(RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),type_=model_config.get("type"))
|
||||
return reranker_client
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e
|
||||
This function is maintained for backward compatibility during migration to the
|
||||
factory pattern. New code should create a MemoryClientFactory instance and use
|
||||
its get_embedder_client method directly.
|
||||
|
||||
Args:
|
||||
embedding_id: Embedding model ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OpenAIEmbedderClient configured for the embedding model
|
||||
"""
|
||||
return MemoryClientFactory(db).get_embedder_client(embedding_id)
|
||||
|
||||
|
||||
def get_reranker_client(rerank_id: str, db: Session) -> OpenAIClient:
|
||||
"""Get reranker client by model ID.
|
||||
|
||||
DEPRECATED: Use MemoryClientFactory(db).get_reranker_client(rerank_id) instead.
|
||||
|
||||
This function is maintained for backward compatibility during migration to the
|
||||
factory pattern. New code should create a MemoryClientFactory instance and use
|
||||
its get_reranker_client method directly.
|
||||
|
||||
Args:
|
||||
rerank_id: Reranker model ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OpenAIClient configured for the reranker model
|
||||
"""
|
||||
return MemoryClientFactory(db).get_reranker_client(rerank_id)
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Any
|
||||
import time
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -25,7 +26,9 @@ async def conflict(evaluate_data: List[Any]) -> List[Any]:
|
||||
冲突记忆列表(JSON 数组)。
|
||||
"""
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
print(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Any
|
||||
import time
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -25,7 +26,9 @@ async def reflexion(ref_data: List[Any]) -> List[Any]:
|
||||
反思结果列表(JSON 数组)。
|
||||
"""
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema)
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
print(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
@@ -5,16 +5,24 @@ This module provides functionality to analyze chunk content and generate insight
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
|
||||
|
||||
class ChunkInsight(BaseModel):
|
||||
"""Pydantic model for chunk insight."""
|
||||
insight: str = Field(..., description="对chunk内容的深度洞察分析")
|
||||
@@ -40,7 +48,7 @@ async def classify_chunk_domain(chunk: str) -> str:
|
||||
Domain name
|
||||
"""
|
||||
try:
|
||||
llm_client = get_llm_client()
|
||||
llm_client = _get_llm_client()
|
||||
|
||||
prompt = f"""请将以下文本内容归类到最合适的领域中。
|
||||
|
||||
@@ -177,7 +185,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str
|
||||
]
|
||||
|
||||
# 调用LLM生成洞察
|
||||
llm_client = get_llm_client()
|
||||
llm_client = _get_llm_client()
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
insight = response.content.strip()
|
||||
|
||||
@@ -5,15 +5,23 @@ This module provides functionality to summarize chunk content using LLM.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
|
||||
|
||||
class ChunkSummary(BaseModel):
|
||||
"""Pydantic model for chunk summary."""
|
||||
summary: str = Field(..., description="简洁的chunk内容摘要")
|
||||
@@ -59,7 +67,7 @@ async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str
|
||||
]
|
||||
|
||||
# 调用LLM生成摘要
|
||||
llm_client = get_llm_client()
|
||||
llm_client = _get_llm_client()
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
summary = response.content.strip()
|
||||
|
||||
@@ -7,14 +7,22 @@ This module provides functionality to extract meaningful tags from chunk content
|
||||
import asyncio
|
||||
from collections import Counter
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
|
||||
|
||||
class ExtractedTags(BaseModel):
|
||||
"""Pydantic model for extracted tags."""
|
||||
tags: List[str] = Field(..., description="从文本中提取的关键标签列表")
|
||||
@@ -56,7 +64,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
|
||||
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
|
||||
)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
llm_client = _get_llm_client()
|
||||
|
||||
# 为每个chunk单独提取标签,然后统计频率
|
||||
all_tags = []
|
||||
@@ -151,7 +159,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
|
||||
]
|
||||
|
||||
# 调用LLM提取人物形象
|
||||
llm_client = get_llm_client()
|
||||
llm_client = _get_llm_client()
|
||||
structured_response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=ExtractedPersona
|
||||
|
||||
@@ -391,6 +391,29 @@ class MemoryConfig:
|
||||
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
||||
config_version: str = "2.0"
|
||||
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise: bool = False
|
||||
enable_llm_disambiguation: bool = False
|
||||
deep_retrieval: bool = True
|
||||
t_type_strict: float = 0.8
|
||||
t_name_strict: float = 0.8
|
||||
t_overall: float = 0.8
|
||||
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity: int = 2
|
||||
include_dialogue_context: bool = False
|
||||
max_dialogue_context_chars: int = 1000
|
||||
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time: float = 0.5
|
||||
lambda_mem: float = 0.5
|
||||
offset: float = 0.0
|
||||
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled: bool = False
|
||||
pruning_scene: Optional[str] = "education"
|
||||
pruning_threshold: float = 0.5
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if not self.config_name or not self.config_name.strip():
|
||||
|
||||
@@ -3,26 +3,26 @@
|
||||
|
||||
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Dict, Any, Optional, List, AsyncGenerator
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.models import AgentConfig, ModelConfig, ModelApiKey
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.services.langchain_tool_server import Search
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_business_logger()
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
@@ -83,17 +83,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
"""
|
||||
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
||||
try:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=question,
|
||||
history=[],
|
||||
search_switch="1",
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
from app.db import get_db
|
||||
db = next(get_db())
|
||||
try:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=question,
|
||||
history=[],
|
||||
search_switch="1",
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
)
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
logger.info(f'用户ID:Agent:{end_user_id}')
|
||||
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
||||
|
||||
@@ -713,9 +719,9 @@ class DraftRunService:
|
||||
Raises:
|
||||
BusinessException: 当指定的会话不存在时
|
||||
"""
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.schemas.conversation_schema import ConversationCreate
|
||||
from app.models import Conversation as ConversationModel
|
||||
from app.schemas.conversation_schema import ConversationCreate
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
conversation_service = ConversationService(self.db)
|
||||
|
||||
|
||||
@@ -7,14 +7,15 @@ Classes:
|
||||
EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import statistics
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
import statistics
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories.neo4j.emotion_repository import EmotionRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.logging_config import get_business_logger
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -454,7 +455,7 @@ class EmotionAnalyticsService:
|
||||
async def generate_emotion_suggestions(
|
||||
self,
|
||||
end_user_id: str,
|
||||
config_id: Optional[int] = None
|
||||
db: Session,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成个性化情绪建议
|
||||
|
||||
@@ -462,7 +463,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID(用户组ID)
|
||||
config_id: 配置ID(可选,用于从数据库加载LLM配置)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 包含个性化建议的响应:
|
||||
@@ -470,14 +471,32 @@ class EmotionAnalyticsService:
|
||||
- suggestions: 建议列表(3-5条)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}")
|
||||
logger.info(f"生成个性化情绪建议: user={end_user_id}")
|
||||
|
||||
# 1. 如果提供了 config_id,从数据库加载配置
|
||||
if config_id is not None:
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置")
|
||||
# 1. 从 end_user_id 获取关联的 memory_config_id
|
||||
llm_client = None
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is not None:
|
||||
from app.services.memory_config_service import (
|
||||
MemoryConfigService,
|
||||
)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=int(config_id),
|
||||
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
|
||||
)
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
|
||||
|
||||
# 2. 获取情绪健康数据
|
||||
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
|
||||
@@ -498,8 +517,9 @@ class EmotionAnalyticsService:
|
||||
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
|
||||
|
||||
# 7. 调用LLM生成建议(使用配置中的LLM)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
llm_client = get_llm_client()
|
||||
if llm_client is None:
|
||||
# 无法获取配置时,抛出错误而不是使用默认配置
|
||||
raise ValueError("无法获取LLM配置,请确保end_user关联了有效的memory_config")
|
||||
|
||||
# 将 prompt 转换为 messages 格式
|
||||
messages = [
|
||||
@@ -598,7 +618,9 @@ class EmotionAnalyticsService:
|
||||
Returns:
|
||||
str: LLM prompt
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt
|
||||
from app.core.memory.utils.prompt.prompt_utils import (
|
||||
render_emotion_suggestions_prompt,
|
||||
)
|
||||
|
||||
prompt = await render_emotion_suggestions_prompt(
|
||||
health_data=health_data,
|
||||
|
||||
@@ -9,10 +9,12 @@ Classes:
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.data_config_model import DataConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,7 +52,9 @@ class EmotionExtractionService:
|
||||
"""
|
||||
if self.llm_client is None or model_id:
|
||||
effective_model_id = model_id or self.llm_id
|
||||
self.llm_client = get_llm_client(effective_model_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(effective_model_id)
|
||||
return self.llm_client
|
||||
|
||||
async def extract_emotion(
|
||||
@@ -142,7 +146,9 @@ class EmotionExtractionService:
|
||||
Returns:
|
||||
Formatted prompt string for LLM
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
|
||||
from app.core.memory.utils.prompt.prompt_utils import (
|
||||
render_emotion_extraction_prompt,
|
||||
)
|
||||
|
||||
prompt = await render_emotion_extraction_prompt(
|
||||
statement=statement,
|
||||
|
||||
@@ -21,8 +21,8 @@ from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.db import get_db
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||
@@ -45,8 +45,7 @@ config_logger = get_config_logger()
|
||||
|
||||
# Initialize Neo4j connector for analytics functions
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
|
||||
class MemoryAgentService:
|
||||
"""Service for memory agent operations"""
|
||||
@@ -55,27 +54,6 @@ class MemoryAgentService:
|
||||
self.user_locks: Dict[str, Lock] = {}
|
||||
self.locks_lock = Lock()
|
||||
|
||||
def load_memory_config(self, config_id: int) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
This method delegates to the centralized MemoryConfigService to avoid
|
||||
code duplication with other services.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
|
||||
Returns:
|
||||
MemoryConfig: Immutable configuration object
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
return MemoryConfigService.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
|
||||
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
|
||||
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
countext = re.findall(r'"status": "(.*?)",', messages)[0]
|
||||
@@ -277,14 +255,17 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str:
|
||||
async def write_memory(self, group_id: str, message: str, config_id: str, db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
Args:
|
||||
group_id: Group identifier
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
user_rag_memory_id: User RAG memory ID
|
||||
|
||||
Returns:
|
||||
Write operation result status
|
||||
@@ -292,14 +273,24 @@ class MemoryAgentService:
|
||||
Raises:
|
||||
ValueError: If config loading fails or write operation fails
|
||||
"""
|
||||
if config_id==None:
|
||||
config_id = os.getenv("config_id")
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# Load configuration from database only
|
||||
try:
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
@@ -366,6 +357,7 @@ class MemoryAgentService:
|
||||
history: List[Dict],
|
||||
search_switch: str,
|
||||
config_id: str,
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
) -> Dict:
|
||||
@@ -378,11 +370,14 @@ class MemoryAgentService:
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
group_id: Group identifier
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
message: User message
|
||||
history: Conversation history
|
||||
search_switch: Search mode switch
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
user_rag_memory_id: User RAG memory ID
|
||||
|
||||
Returns:
|
||||
Dict with 'answer' and 'intermediate_outputs' keys
|
||||
@@ -394,8 +389,13 @@ class MemoryAgentService:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
if config_id==None:
|
||||
config_id = os.getenv("config_id")
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
|
||||
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
|
||||
|
||||
@@ -411,7 +411,11 @@ class MemoryAgentService:
|
||||
with group_lock:
|
||||
# Step 1: Load configuration from database only
|
||||
try:
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
@@ -696,7 +700,11 @@ class MemoryAgentService:
|
||||
logger.info("Classifying message type")
|
||||
|
||||
# Load configuration to get LLM model ID
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
|
||||
status = await status_typle(message, memory_config.llm_model_id)
|
||||
logger.debug(f"Message type: {status}")
|
||||
@@ -865,7 +873,8 @@ class MemoryAgentService:
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user_id: Optional[str] = None,
|
||||
llm_id: Optional[str] = None
|
||||
llm_id: Optional[str] = None,
|
||||
db: Session = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -877,6 +886,7 @@ class MemoryAgentService:
|
||||
- end_user_id: 用户ID(可选)
|
||||
- current_user_id: 当前登录用户的ID(保留参数)
|
||||
- llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成)
|
||||
- db: 数据库会话(可选)
|
||||
|
||||
返回格式:
|
||||
{
|
||||
@@ -893,7 +903,7 @@ class MemoryAgentService:
|
||||
|
||||
# 1. 根据 end_user_id 获取 end_user_name
|
||||
try:
|
||||
if end_user_id:
|
||||
if end_user_id and db:
|
||||
from app.repositories import end_user_repository
|
||||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||||
|
||||
@@ -948,7 +958,9 @@ class MemoryAgentService:
|
||||
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
|
||||
|
||||
# 使用LLM提取标签
|
||||
llm_client = get_llm_client(llm_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(llm_id)
|
||||
|
||||
# 定义标签提取的结构
|
||||
class UserTags(BaseModel):
|
||||
@@ -1110,4 +1122,69 @@ class MemoryAgentService:
|
||||
# "msg": "解析失败",
|
||||
# "error_code": "DOC_PARSE_ERROR",
|
||||
# "data": {"error": str(e)}
|
||||
# }
|
||||
# }
|
||||
|
||||
|
||||
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
|
||||
通过以下流程获取配置:
|
||||
1. 根据 end_user_id 获取用户的 app_id
|
||||
2. 获取该应用的最新发布版本
|
||||
3. 从发布版本的 config 字段中提取 memory_config_id
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
包含 memory_config_id 和相关信息的字典
|
||||
|
||||
Raises:
|
||||
ValueError: 当终端用户不存在或应用未发布时
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
# 1. 获取 end_user 及其 app_id
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if not end_user:
|
||||
logger.warning(f"End user not found: {end_user_id}")
|
||||
raise ValueError(f"终端用户不存在: {end_user_id}")
|
||||
|
||||
app_id = end_user.app_id
|
||||
logger.debug(f"Found end_user app_id: {app_id}")
|
||||
|
||||
# 2. 获取该应用的最新发布版本
|
||||
stmt = (
|
||||
select(AppRelease)
|
||||
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
|
||||
.order_by(AppRelease.version.desc())
|
||||
)
|
||||
latest_release = db.scalars(stmt).first()
|
||||
|
||||
if not latest_release:
|
||||
logger.warning(f"No active release found for app: {app_id}")
|
||||
raise ValueError(f"应用未发布: {app_id}")
|
||||
|
||||
logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")
|
||||
|
||||
# 3. 从 config 中提取 memory_config_id
|
||||
config = latest_release.config or {}
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
result = {
|
||||
"end_user_id": str(end_user_id),
|
||||
"app_id": str(app_id),
|
||||
"release_id": str(latest_release.id),
|
||||
"release_version": latest_release.version,
|
||||
"memory_config_id": memory_config_id
|
||||
}
|
||||
|
||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
|
||||
return result
|
||||
@@ -3,7 +3,6 @@ Memory Configuration Service
|
||||
|
||||
Centralized configuration loading and management for memory services.
|
||||
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
|
||||
Database session management is handled internally.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -57,7 +56,7 @@ def _validate_config_id(config_id):
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return parsed_id
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid configuration ID format: '{config_id}'",
|
||||
field_name="config_id",
|
||||
@@ -77,19 +76,29 @@ class MemoryConfigService:
|
||||
|
||||
This class provides a single implementation of configuration loading logic
|
||||
that can be shared across multiple services, eliminating code duplication.
|
||||
Database session management is handled internally.
|
||||
|
||||
Usage:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __init__(self, db: Session):
|
||||
"""Initialize the service with a database session.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: int,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
This method manages its own database session internally.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
service_name: Name of the calling service (for logging purposes)
|
||||
@@ -100,27 +109,6 @@ class MemoryConfigService:
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
from app.db import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
try:
|
||||
return MemoryConfigService._load_memory_config_with_db(
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
service_name=service_name,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@staticmethod
|
||||
def _load_memory_config_with_db(
|
||||
config_id: int,
|
||||
db: Session,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""Internal method that loads memory configuration with an existing db session."""
|
||||
start_time = time.time()
|
||||
|
||||
config_logger.info(
|
||||
@@ -137,7 +125,7 @@ class MemoryConfigService:
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id)
|
||||
|
||||
result = DataConfigRepository.get_config_with_workspace(db, validated_config_id)
|
||||
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||
if not result:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
@@ -160,7 +148,7 @@ class MemoryConfigService:
|
||||
embedding_uuid = validate_embedding_model(
|
||||
validated_config_id,
|
||||
memory_config.embedding_id,
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
workspace.id,
|
||||
)
|
||||
@@ -169,7 +157,7 @@ class MemoryConfigService:
|
||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||
memory_config.llm_id,
|
||||
"llm",
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=True,
|
||||
config_id=validated_config_id,
|
||||
@@ -183,7 +171,7 @@ class MemoryConfigService:
|
||||
rerank_uuid, rerank_name = validate_and_resolve_model_id(
|
||||
memory_config.rerank_id,
|
||||
"rerank",
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
@@ -194,7 +182,7 @@ class MemoryConfigService:
|
||||
embedding_name, _ = validate_model_exists_and_active(
|
||||
embedding_uuid,
|
||||
"embedding",
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
@@ -220,6 +208,25 @@ class MemoryConfigService:
|
||||
reflexion_range=memory_config.reflexion_range or "retrieval",
|
||||
reflexion_baseline=memory_config.baseline or "time",
|
||||
loaded_at=datetime.now(),
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
@@ -262,3 +269,131 @@ class MemoryConfigService:
|
||||
raise
|
||||
else:
|
||||
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
|
||||
|
||||
def get_model_config(self, model_id: str) -> dict:
|
||||
"""Get LLM model configuration by ID.
|
||||
|
||||
Args:
|
||||
model_id: Model ID to look up
|
||||
|
||||
Returns:
|
||||
Dict with model configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.core.config import settings
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
|
||||
if not config:
|
||||
logger.warning(f"Model ID {model_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"type": config.type,
|
||||
"timeout": settings.LLM_TIMEOUT,
|
||||
"max_retries": settings.LLM_MAX_RETRIES,
|
||||
}
|
||||
|
||||
def get_embedder_config(self, embedding_id: str) -> dict:
|
||||
"""Get embedding model configuration by ID.
|
||||
|
||||
Args:
|
||||
embedding_id: Embedding model ID to look up
|
||||
|
||||
Returns:
|
||||
Dict with embedder configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
|
||||
if not config:
|
||||
logger.warning(f"Embedding model ID {embedding_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"type": config.type,
|
||||
"timeout": 120.0,
|
||||
"max_retries": 5,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_pipeline_config(memory_config: MemoryConfig):
|
||||
"""Build ExtractionPipelineConfig from MemoryConfig.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing all pipeline settings.
|
||||
|
||||
Returns:
|
||||
ExtractionPipelineConfig with deduplication, statement extraction,
|
||||
and forgetting engine settings.
|
||||
"""
|
||||
from app.core.memory.models.variate_config import (
|
||||
DedupConfig,
|
||||
ExtractionPipelineConfig,
|
||||
ForgettingEngineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
|
||||
dedup_config = DedupConfig(
|
||||
enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
|
||||
enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
|
||||
fuzzy_name_threshold_strict=memory_config.t_name_strict,
|
||||
fuzzy_type_threshold_strict=memory_config.t_type_strict,
|
||||
fuzzy_overall_threshold=memory_config.t_overall,
|
||||
)
|
||||
|
||||
stmt_config = StatementExtractionConfig(
|
||||
statement_granularity=memory_config.statement_granularity,
|
||||
include_dialogue_context=memory_config.include_dialogue_context,
|
||||
max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
|
||||
)
|
||||
|
||||
forget_config = ForgettingEngineConfig(
|
||||
offset=memory_config.offset,
|
||||
lambda_time=memory_config.lambda_time,
|
||||
lambda_mem=memory_config.lambda_mem,
|
||||
)
|
||||
|
||||
return ExtractionPipelineConfig(
|
||||
statement_extraction=stmt_config,
|
||||
deduplication=dedup_config,
|
||||
forgetting_engine=forget_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_pruning_config(memory_config: MemoryConfig) -> dict:
|
||||
"""Retrieve semantic pruning config from MemoryConfig.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing pruning settings.
|
||||
|
||||
Returns:
|
||||
Dict suitable for PruningConfig.model_validate with keys:
|
||||
- pruning_switch: bool
|
||||
- pruning_scene: str
|
||||
- pruning_threshold: float
|
||||
"""
|
||||
return {
|
||||
"pruning_switch": memory_config.pruning_enabled,
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
}
|
||||
|
||||
@@ -49,27 +49,6 @@ class MemoryStorageService:
|
||||
|
||||
def __init__(self):
|
||||
logger.info("MemoryStorageService initialized")
|
||||
|
||||
def load_memory_config(self, config_id: int, db: Session) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
This method delegates to the centralized MemoryConfigService to avoid
|
||||
code duplication with other services.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
|
||||
Returns:
|
||||
MemoryConfig: Immutable configuration object
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
return MemoryConfigService.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryStorageService"
|
||||
)
|
||||
|
||||
async def get_storage_info(self) -> dict:
|
||||
"""
|
||||
@@ -293,7 +272,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
|
||||
# Load configuration from database only using centralized manager
|
||||
try:
|
||||
memory_config = MemoryConfigService.load_memory_config(
|
||||
config_service = MemoryConfigService(self.db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=int(cid),
|
||||
service_name="MemoryStorageService.pilot_run_stream"
|
||||
)
|
||||
@@ -320,13 +300,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
async def run_pipeline():
|
||||
"""在后台执行管线并捕获异常"""
|
||||
try:
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
from app.services.pilot_run_service import run_pilot_extraction
|
||||
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(
|
||||
dialogue_text=dialogue_text,
|
||||
is_pilot_run=True,
|
||||
progress_callback=progress_callback
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
|
||||
await run_pilot_extraction(
|
||||
memory_config=memory_config,
|
||||
dialogue_text=dialogue_text,
|
||||
db=self.db,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
|
||||
|
||||
|
||||
219
api/app/services/pilot_run_service.py
Normal file
219
api/app/services/pilot_run_service.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Pilot Run Service - 试运行服务
|
||||
|
||||
用于执行记忆系统的试运行流程,不保存到 Neo4j。
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger, log_time
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
get_chunked_dialogs_from_preprocessed,
|
||||
)
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_pipeline_config,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
async def run_pilot_extraction(
|
||||
memory_config: MemoryConfig,
|
||||
dialogue_text: str,
|
||||
db: Session,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
执行试运行模式的知识提取流水线。
|
||||
|
||||
Args:
|
||||
memory_config: 从数据库加载的内存配置对象
|
||||
dialogue_text: 输入的对话文本
|
||||
progress_callback: 可选的进度回调函数
|
||||
- 参数1 (stage): 当前处理阶段标识符
|
||||
- 参数2 (message): 人类可读的进度消息
|
||||
- 参数3 (data): 可选的附加数据字典
|
||||
"""
|
||||
log_file = "logs/time.log"
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
neo4j_connector = None
|
||||
|
||||
try:
|
||||
# 步骤 1: 初始化客户端
|
||||
logger.info("Initializing clients...")
|
||||
step_start = time.time()
|
||||
|
||||
client_factory = MemoryClientFactory(db)
|
||||
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
|
||||
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
log_time("Client Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 2: 解析对话文本
|
||||
logger.info("Parsing dialogue text...")
|
||||
step_start = time.time()
|
||||
|
||||
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
|
||||
pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)"
|
||||
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
|
||||
messages = [
|
||||
ConversationMessage(role=r, msg=c.strip())
|
||||
for r, c in matches
|
||||
if c.strip()
|
||||
]
|
||||
|
||||
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
|
||||
if not messages:
|
||||
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
|
||||
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context,
|
||||
ref_id="pilot_dialog_1",
|
||||
group_id=str(memory_config.workspace_id),
|
||||
user_id=str(memory_config.tenant_id),
|
||||
apply_id=str(memory_config.config_id),
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"},
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
chunker_strategy=memory_config.chunker_strategy,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed dialogue text: {len(messages)} messages")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dlg in chunked_dialogs:
|
||||
for i, chunk in enumerate(dlg.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dlg.id,
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 3: 初始化流水线编排器
|
||||
logger.info("Initializing extraction orchestrator...")
|
||||
step_start = time.time()
|
||||
|
||||
config = get_pipeline_config(memory_config)
|
||||
logger.info(
|
||||
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
|
||||
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
|
||||
)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 4: 执行知识提取流水线
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=True,
|
||||
)
|
||||
|
||||
# 解包 extraction_result tuple (与 main.py 保持一致)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("generating_results", "正在生成结果...")
|
||||
|
||||
# 步骤 5: 生成记忆摘要(与 main.py 保持一致)
|
||||
try:
|
||||
logger.info("Generating memory summaries...")
|
||||
step_start = time.time()
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs,
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
)
|
||||
|
||||
log_time("Memory Summary Generation", time.time() - step_start, log_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
|
||||
logger.info("Pilot run completed: Skipping Neo4j save")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if neo4j_connector:
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")
|
||||
@@ -176,7 +176,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
"""Celery task to process a read message via MemoryAgentService.
|
||||
|
||||
Args:
|
||||
group_id: Group ID for the memory agent
|
||||
group_id: Group ID for the memory agent (also used as end_user_id)
|
||||
message: User message to process
|
||||
history: Conversation history
|
||||
search_switch: Search switch parameter
|
||||
@@ -190,9 +190,28 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Resolve config_id if None
|
||||
actual_config_id = config_id
|
||||
if actual_config_id is None:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
service = MemoryAgentService()
|
||||
return await service.read_memory(group_id, message, history, search_switch, config_id,storage_type,user_rag_memory_id)
|
||||
db = next(get_db())
|
||||
try:
|
||||
service = MemoryAgentService()
|
||||
return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
@@ -246,7 +265,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
|
||||
Args:
|
||||
group_id: Group ID for the memory agent
|
||||
group_id: Group ID for the memory agent (also used as end_user_id)
|
||||
message: Message to write
|
||||
config_id: Optional configuration ID
|
||||
|
||||
@@ -258,9 +277,28 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Resolve config_id if None
|
||||
actual_config_id = config_id
|
||||
if actual_config_id is None:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
service = MemoryAgentService()
|
||||
return await service.write_memory(group_id, message, config_id,storage_type,user_rag_memory_id)
|
||||
db = next(get_db())
|
||||
try:
|
||||
service = MemoryAgentService()
|
||||
return await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
|
||||
Reference in New Issue
Block a user