refactor(memory): restructure memory system and improve configuration management
- Remove deprecated main.py entry point from memory module
- Reorganize imports across controllers and services for consistency
- Update emotion controller to pass db session instead of config_id to services
- Enhance memory agent controller with db session parameter for status_type and user_profile endpoints
- Refactor memory agent service to accept db parameter in classify_message_type method
- Improve configuration handling in celery_app by removing automatic database reload
- Update all memory-related services to use centralized config management
- Standardize import ordering and remove unused imports across 50+ files
- Add pilot_run_service for new pilot execution workflow
- Refactor extraction engine, reflection engine, and search services for better modularity
- Update LLM utilities and embedder configuration for improved flexibility
- Enhance type classifier and verification tools with better error handling
- Improve memory evaluation modules (LOCOMO, LongMemEval, MemSciQA) with consistent patterns
This commit is contained in:
@@ -141,7 +141,7 @@ class SearchService:
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
try:
|
||||
# Execute search using embedding_model_id from memory_config
|
||||
# Execute search using memory_config
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
@@ -149,7 +149,7 @@ class SearchService:
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
embedding_id=str(config.embedding_model_id),
|
||||
memory_config=config,
|
||||
rerank_alpha=rerank_alpha,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ from app.core.memory.agent.mcp_server.models.retrieval_models import (
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
@@ -41,8 +42,10 @@ async def Data_type_differentiation(
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
# Get LLM client from memory_config using factory pattern
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Render template
|
||||
try:
|
||||
|
||||
@@ -16,7 +16,8 @@ from app.core.memory.agent.mcp_server.models.problem_models import (
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
@@ -56,7 +57,9 @@ async def Split_The_Problem(
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Extract user ID from session
|
||||
user_id = session_service.resolve_user_id(sessionid)
|
||||
@@ -190,7 +193,9 @@ async def Problem_Extension(
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID from usermessages
|
||||
from app.core.memory.agent.utils.messages_tool import Resolve_username
|
||||
|
||||
@@ -21,8 +21,9 @@ from app.core.memory.agent.utils.messages_tool import (
|
||||
Resolve_username,
|
||||
Summary_messages_deal,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server.fastmcp import Context
|
||||
@@ -66,7 +67,9 @@ async def Summary(
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
@@ -210,7 +213,9 @@ async def Retrieve_Summary(
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
@@ -425,7 +430,9 @@ async def Input_Summary(
|
||||
search_service = get_context_resource(ctx, "search_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
"""
|
||||
Type classification utility for distinguishing read/write operations.
|
||||
"""
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.config import settings
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -44,7 +43,9 @@ async def status_typle(messages: str, llm_model_id: str) -> dict:
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
llm_client = get_llm_client(llm_model_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(llm_model_id)
|
||||
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
from typing import TypedDict, Annotated, List, Any
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph, add_messages
|
||||
import asyncio
|
||||
import json
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
import os
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from langchain_core.messages import HumanMessage
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from typing import Annotated, Any, List, TypedDict
|
||||
|
||||
# Removed global variable imports - use dependency injection instead
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from langchain_core.messages import AnyMessage, HumanMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph, add_messages
|
||||
|
||||
load_dotenv(find_dotenv())
|
||||
|
||||
@@ -53,7 +54,9 @@ class VerifyTool:
|
||||
async def model_1(self, state: State) -> State:
|
||||
if not self.llm_model_id:
|
||||
raise ValueError("llm_model_id is required but not provided")
|
||||
llm_client = get_llm_client(self.llm_model_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(self.llm_model_id)
|
||||
response_content = await llm_client.chat(
|
||||
messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])]
|
||||
)
|
||||
|
||||
@@ -13,13 +13,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
Memory_summary_generation,
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.core.memory.utils.embedder.embedder_utils import (
|
||||
get_embedder_client_from_config,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
@@ -67,9 +65,11 @@ async def write(
|
||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||
logger.info(f"Group ID: {group_id}")
|
||||
|
||||
# Construct clients from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
embedder_client = get_embedder_client_from_config(memory_config)
|
||||
# Construct clients from memory_config using factory pattern with db session
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
embedder_client = factory.get_embedder_client_from_config(memory_config)
|
||||
logger.info("LLM and embedding clients constructed")
|
||||
|
||||
# Initialize timing log
|
||||
@@ -100,7 +100,7 @@ async def write(
|
||||
# Step 2: Initialize and run ExtractionOrchestrator
|
||||
step_start = time.time()
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
pipeline_config = get_pipeline_config()
|
||||
pipeline_config = get_pipeline_config(memory_config)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
@@ -155,8 +155,8 @@ async def write(
|
||||
# Step 4: Generate Memory summaries and save to Neo4j
|
||||
step_start = time.time()
|
||||
try:
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -35,7 +35,9 @@ except NameError:
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
#TODO: Fix this
|
||||
# Default values (previously from definitions.py)
|
||||
@@ -47,11 +49,37 @@ class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], llm_client) -> List[str]:
|
||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
"""
|
||||
try:
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
# 3. 构建Prompt
|
||||
tag_list_str = ", ".join(tags)
|
||||
@@ -156,8 +184,7 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
llm_client = get_llm_client(DEFAULT_LLM_ID)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, llm_client)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
|
||||
@@ -18,8 +18,10 @@ if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
#TODO: Fix this
|
||||
@@ -59,7 +61,33 @@ class MemoryInsight:
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
self.llm_client = get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
"""关闭数据库连接。"""
|
||||
|
||||
@@ -25,8 +25,10 @@ except Exception:
|
||||
pass
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
@@ -47,7 +49,33 @@ class UserSummary:
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.connector = Neo4jConnector()
|
||||
self.llm = get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
await self.connector.close()
|
||||
|
||||
@@ -1,22 +1,34 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_CHUNKER_STRATEGY, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_CHUNKER_STRATEGY,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
|
||||
# Import from database module
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# Cypher queries for evaluation
|
||||
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
|
||||
@@ -52,7 +64,9 @@ async def ingest_contexts_via_full_pipeline(
|
||||
llm_available = True
|
||||
try:
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
|
||||
llm_available = False
|
||||
@@ -133,12 +147,13 @@ async def ingest_contexts_via_full_pipeline(
|
||||
return False
|
||||
|
||||
# 初始化 embedder 客户端
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
embedder_config_dict = get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
except Exception as e:
|
||||
@@ -236,15 +251,15 @@ async def ingest_contexts_via_full_pipeline(
|
||||
print("[Ingestion] Generating memory summaries...")
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
Memory_summary_generation,
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
|
||||
summaries = await Memory_summary_generation(
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs=dialog_data_list,
|
||||
llm_client=llm_client,
|
||||
embedding_id=embedding_name or SELECTED_EMBEDDING_ID
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
|
||||
except Exception as e:
|
||||
|
||||
@@ -15,7 +15,7 @@ import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -23,37 +23,38 @@ except ImportError:
|
||||
def load_dotenv():
|
||||
pass
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
SELECTED_EMBEDDING_ID
|
||||
)
|
||||
from app.core.memory.utils.llm_utils import get_llm_client
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
f1_score,
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
avg_context_tokens
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_metrics import (
|
||||
get_category_name,
|
||||
locomo_f1_score,
|
||||
locomo_multi_f1,
|
||||
get_category_name
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_utils import (
|
||||
load_locomo_data,
|
||||
extract_conversations,
|
||||
ingest_conversations_if_needed,
|
||||
load_locomo_data,
|
||||
resolve_temporal_references,
|
||||
select_and_format_information,
|
||||
retrieve_relevant_information,
|
||||
ingest_conversations_if_needed
|
||||
select_and_format_information,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
async def run_locomo_benchmark(
|
||||
@@ -160,10 +161,16 @@ async def run_locomo_benchmark(
|
||||
# Step 3: Initialize clients
|
||||
print("🔧 Initializing clients...")
|
||||
connector = Neo4jConnector()
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Initialize LLM client with database context
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Initialize embedder
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
# file name: check_neo4j_connection_fixed.py
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 1
|
||||
# 添加项目根目录到路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -34,7 +36,7 @@ def _loc_normalize(text: str) -> str:
|
||||
|
||||
# 尝试从 metrics.py 导入基础指标
|
||||
try:
|
||||
from common.metrics import f1_score, bleu1, jaccard
|
||||
from common.metrics import bleu1, f1_score, jaccard
|
||||
print("✅ 从 metrics.py 导入基础指标成功")
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 metrics.py 导入失败: {e}")
|
||||
@@ -111,10 +113,14 @@ try:
|
||||
|
||||
# 尝试从不同位置导入
|
||||
try:
|
||||
from locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
|
||||
from locomo.qwen_search_eval import (
|
||||
_resolve_relative_times,
|
||||
loc_f1_score,
|
||||
loc_multi_f1,
|
||||
)
|
||||
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
except ImportError:
|
||||
from qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
|
||||
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
|
||||
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
|
||||
except ImportError as e:
|
||||
@@ -429,13 +435,17 @@ async def run_enhanced_evaluation():
|
||||
return None
|
||||
|
||||
# 修正导入路径:使用 app.core.memory.src 前缀
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# 加载数据
|
||||
# 获取项目根目录
|
||||
@@ -458,10 +468,14 @@ async def run_enhanced_evaluation():
|
||||
# 初始化增强监控器
|
||||
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
|
||||
|
||||
llm = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化embedder
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -2,10 +2,11 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
import statistics
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
@@ -13,16 +14,31 @@ except Exception:
|
||||
return None
|
||||
|
||||
import re
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
|
||||
@@ -327,9 +343,13 @@ async def run_locomo_eval(
|
||||
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
|
||||
|
||||
# 使用异步LLM客户端
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
# 初始化embedder用于直接调用
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -2,11 +2,11 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -16,6 +16,7 @@ except Exception:
|
||||
|
||||
# 确保可以找到 src 及项目根路径
|
||||
import sys
|
||||
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR)))
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
@@ -25,19 +26,33 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
|
||||
# 与现有评估脚本保持一致的导入方式
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
try:
|
||||
# 优先从 extraction_utils 导入
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline # type: ignore
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline, # type: ignore
|
||||
)
|
||||
except Exception:
|
||||
ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import exact_match
|
||||
except Exception:
|
||||
@@ -686,9 +701,13 @@ async def run_longmemeval_test(
|
||||
)
|
||||
|
||||
# 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
connector = Neo4jConnector()
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
@@ -748,10 +767,10 @@ async def run_longmemeval_test(
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
# for sm in summaries:
|
||||
# summary_text = str(sm.get("summary", "")).strip()
|
||||
# if summary_text:
|
||||
# contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要(最多3个)
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
|
||||
@@ -2,11 +2,11 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,15 +15,26 @@ except Exception:
|
||||
return None
|
||||
|
||||
# 与现有评估脚本保持一致的导入方式
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.llm_utils import get_llm_client
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import exact_match
|
||||
except Exception:
|
||||
@@ -647,9 +658,13 @@ async def run_longmemeval_test(
|
||||
items = qa_list[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化组件 - 使用异步LLM客户端
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
connector = Neo4jConnector()
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -4,19 +4,35 @@ import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
try:
    from dotenv import load_dotenv
except Exception:
    # Importing dotenv failed: define a stand-in with the same name so later
    # calls to load_dotenv() still work.
    def load_dotenv():
        """Do nothing and return None (replacement for dotenv.load_dotenv)."""
        return None
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
@@ -119,7 +135,7 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
|
||||
return merged
|
||||
|
||||
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid") -> Dict[str, Any]:
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
# Load data
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
@@ -134,7 +150,9 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
await ingest_contexts_via_full_pipeline(contexts, group_id)
|
||||
|
||||
# LLM client (使用异步调用)
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Evaluate each item
|
||||
connector = Neo4jConnector()
|
||||
@@ -159,6 +177,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
limit=search_limit,
|
||||
include=["dialogues", "statements", "entities"],
|
||||
output_path=None,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
except Exception:
|
||||
results = None
|
||||
@@ -242,7 +261,11 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
|
||||
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
|
||||
correct_flags.append(exact_match(pred, reference))
|
||||
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
)
|
||||
f1s.append(f1_score(str(pred), str(reference)))
|
||||
b1s.append(bleu1(str(pred), str(reference)))
|
||||
jss.append(jaccard(str(pred), str(reference)))
|
||||
|
||||
@@ -2,10 +2,10 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,6 +15,7 @@ except Exception:
|
||||
|
||||
# 路径与模块导入保持与现有评估脚本一致
|
||||
import sys
|
||||
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
@@ -23,17 +24,27 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
sys.path.insert(0, _p)
|
||||
|
||||
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
|
||||
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
|
||||
except Exception:
|
||||
# 兜底:简单实现(必要时)
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
@@ -226,13 +237,17 @@ async def run_memsciqa_test(
|
||||
items = all_items[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化 LLM(纯测试:不进行摄入)
|
||||
llm = get_llm_client(SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test)
|
||||
connector = Neo4jConnector()
|
||||
embedder = None
|
||||
if search_type in ("embedding", "hybrid"):
|
||||
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
@@ -5,18 +5,17 @@ OpenAI LLM 客户端实现
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.llm import RedBearLLM
|
||||
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
|
||||
from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,7 +42,7 @@ class OpenAIClient(LLMClient):
|
||||
|
||||
# 初始化 Langfuse 回调处理器(如果启用)
|
||||
self.langfuse_handler = None
|
||||
if LANGFUSE_ENABLED:
|
||||
if settings.LANGFUSE_ENABLED:
|
||||
try:
|
||||
from langfuse.langchain import CallbackHandler
|
||||
self.langfuse_handler = CallbackHandler()
|
||||
|
||||
@@ -1,430 +0,0 @@
|
||||
"""
|
||||
MemSci 记忆系统主入口 - 重构版本
|
||||
|
||||
该模块是重构后的记忆系统主入口,使用新的模块化架构。
|
||||
旧版本入口(app/core/memory/src/main.py)已删除。
|
||||
|
||||
主要功能:
|
||||
1. 协调整个知识提取流水线
|
||||
2. 支持试运行模式和正常运行模式
|
||||
3. 使用重构后的 storage_services 模块
|
||||
4. 提供统一的配置管理和日志记录
|
||||
|
||||
作者:Lance77
|
||||
日期:2025-11-22
|
||||
"""
|
||||
|
||||
# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误
|
||||
import os
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "false"
|
||||
os.environ["LANGCHAIN_TRACING"] = "false"
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, Callable, Awaitable
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 导入重构后的模块
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
|
||||
# 导入数据加载函数
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
get_chunked_dialogs_with_preprocessing,
|
||||
get_chunked_dialogs_from_preprocessed,
|
||||
)
|
||||
# 导入配置模块(而不是直接导入变量)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.logging_config import get_memory_logger, log_time
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def main(
    # Required configuration parameters (no longer from global variables)
    chunker_strategy: str,
    group_id: str,
    user_id: str,
    apply_id: str,
    llm_model_id: str,
    embedding_model_id: str,
    # Optional parameters
    dialogue_text: Optional[str] = None,
    is_pilot_run: bool = False,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
):
    """Run the MemSci knowledge-extraction pipeline end to end.

    Workflow:
        1. Create the LLM client, embedder client and Neo4j connector.
        2. Build chunked dialogues, either from ``dialogue_text`` or from the
           bundled ``data/testdata.json`` next to this module.
        3. Build the extraction orchestrator from the pipeline config.
        4. Run extraction.
        5. Save the graph to Neo4j (skipped when ``is_pilot_run`` is true).
        6. Generate memory summaries (stored only when not a pilot run).

    Args:
        chunker_strategy: Chunking strategy passed to the chunking helpers.
        group_id: Group ID stamped on the dialogue data.
        user_id: User ID stamped on the dialogue data.
        apply_id: Application ID stamped on the dialogue data.
        llm_model_id: ID used to obtain the LLM client.
        embedding_model_id: ID used to obtain the embedder configuration.
        dialogue_text: Raw dialogue text. When it is a non-blank string it is
            parsed and used as the only input; otherwise testdata.json is used.
        is_pilot_run: When true, nothing is written to Neo4j.
        progress_callback: Optional awaitable ``(stage, message, data)`` hook.
            It is also invoked with only ``(stage, message)``, so ``data``
            must be optional on the callee side.

    Raises:
        FileNotFoundError: ``dialogue_text`` is unusable and testdata.json
            does not exist.
        Exception: Any other pipeline failure is logged, printed and re-raised.
    """
    separator = "=" * 60
    print(separator)
    print("MemSci 知识提取流水线 - 重构版本")
    print(separator)
    print(f"运行模式: {'试运行(不保存到Neo4j)' if is_pilot_run else '正常运行(保存到Neo4j)'}")
    print("Using chunker strategy:", chunker_strategy)
    print("Using group ID:", group_id)
    print("Using model ID:", llm_model_id)
    print("Using embedding model ID:", embedding_model_id)
    print(separator)

    # Open the timing log and mark the start of this run.
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n")

    pipeline_start = time.time()

    # Bound before the try block so the cleanup in `finally` never touches an
    # unassigned name when client initialisation itself fails.
    neo4j_connector = None

    try:
        # Step 1: initialise clients.
        logger.info("Initializing clients...")
        step_start = time.time()

        llm_client = get_llm_client(llm_model_id)

        # Turn the embedder config dict into a RedBearModelConfig object.
        from app.core.models.base import RedBearModelConfig
        embedder_config_dict = get_embedder_config(embedding_model_id)
        embedder_config = RedBearModelConfig(**embedder_config_dict)
        embedder_client = OpenAIEmbedderClient(embedder_config)

        neo4j_connector = Neo4jConnector()

        log_time("Client Initialization", time.time() - step_start, log_file)

        # Step 2: load or prepare the data.
        logger.info("Loading data...")
        logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}")
        logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}")
        logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}")
        step_start = time.time()

        if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip():
            # Dialogue text supplied by the front end.
            logger.info("[MAIN] ✓ Using frontend dialogue text (pilot run mode)")
            dialog, message_count = _build_pilot_dialog(dialogue_text, group_id, user_id, apply_id)

            if progress_callback:
                await progress_callback("text_preprocessing", "开始预处理文本...")

            chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
                data=[dialog],
                chunker_strategy=chunker_strategy,
                llm_client=llm_client,
            )
            logger.info(f"Processed frontend dialogue text: {message_count} messages")
        else:
            # No usable dialogue text: fall back to the bundled test data.
            logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
            logger.info("Loading data from testdata.json...")
            test_data_path = os.path.join(
                os.path.dirname(__file__), "data", "testdata.json"
            )

            if not os.path.exists(test_data_path):
                raise FileNotFoundError(f"Test data file not found: {test_data_path}")

            if progress_callback:
                await progress_callback("text_preprocessing", "开始预处理文本...")

            chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
                chunker_strategy=chunker_strategy,
                group_id=group_id,
                user_id=user_id,
                apply_id=apply_id,
                indices=None,
                input_data_path=test_data_path,
                llm_client=llm_client,
                skip_cleaning=True,
            )
            logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")

        # Both branches report chunking results the same way.
        if progress_callback:
            await _report_chunking_progress(progress_callback, chunked_dialogs, chunker_strategy)

        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 3: initialise the extraction orchestrator.
        logger.info("Initializing extraction orchestrator...")
        step_start = time.time()

        from app.core.memory.utils.config.config_utils import get_pipeline_config
        config = get_pipeline_config()

        logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}")

        orchestrator = ExtractionOrchestrator(
            llm_client=llm_client,
            embedder_client=embedder_client,
            connector=neo4j_connector,
            config=config,
            progress_callback=progress_callback,
            embedding_id=embedding_model_id,
        )

        log_time("Orchestrator Initialization", time.time() - step_start, log_file)

        # Step 4: run the extraction pipeline.
        logger.info("Running extraction pipeline...")
        step_start = time.time()

        if progress_callback:
            await progress_callback("knowledge_extraction", "正在知识抽取...")

        extraction_result = await orchestrator.run(
            dialog_data_list=chunked_dialogs,
            is_pilot_run=is_pilot_run,
        )

        # The result is unpacked as a 7-tuple of node and edge collections.
        (
            dialogue_nodes,
            chunk_nodes,
            statement_nodes,
            entity_nodes,
            statement_chunk_edges,
            statement_entity_edges,
            entity_edges,
        ) = extraction_result

        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        if progress_callback:
            await progress_callback("generating_results", "正在生成结果...")

        # Step 5: persist the results, or just report them in pilot mode.
        if is_pilot_run:
            logger.info("Pilot run mode: Skipping Neo4j save")
            print("\n试运行模式:跳过 Neo4j 保存,流水线处理完成。")
            print("提取结果已生成,可在相关输出中查看。")
        else:
            logger.info("Normal mode: Saving to Neo4j...")
            step_start = time.time()

            # Index/constraint creation is best effort: a failure is logged
            # and the save below is still attempted.
            try:
                from app.repositories.neo4j.create_indexes import (
                    create_fulltext_indexes,
                    create_unique_constraints,
                )
                await create_fulltext_indexes()
                await create_unique_constraints()
                logger.info("Successfully created indexes and constraints")
            except Exception as e:
                logger.error(f"Error creating indexes/constraints: {e}")

            # A save failure is reported but does not abort the run.
            try:
                from app.repositories.neo4j.graph_saver import (
                    save_dialog_and_statements_to_neo4j,
                )

                success = await save_dialog_and_statements_to_neo4j(
                    dialogue_nodes=dialogue_nodes,
                    chunk_nodes=chunk_nodes,
                    statement_nodes=statement_nodes,
                    entity_nodes=entity_nodes,
                    statement_chunk_edges=statement_chunk_edges,
                    statement_entity_edges=statement_entity_edges,
                    entity_edges=entity_edges,
                    connector=neo4j_connector,
                )

                if success:
                    logger.info("Successfully saved all data to Neo4j")
                    print("\n✓ 成功保存所有数据到 Neo4j")
                else:
                    logger.warning("Failed to save some data to Neo4j")
                    print("\n⚠ 部分数据保存到 Neo4j 失败")
            except Exception as e:
                logger.error(f"Error saving to Neo4j: {e}", exc_info=True)
                print(f"\n✗ 保存到 Neo4j 失败: {e}")

            log_time("Neo4j Database Save", time.time() - step_start, log_file)

        # Step 6: memory summaries. Optional: a failure is logged only.
        try:
            logger.info("Generating memory summaries...")
            step_start = time.time()

            from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
                Memory_summary_generation,
            )
            from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
            from app.repositories.neo4j.add_edges import (
                add_memory_summary_statement_edges,
            )

            summaries = await Memory_summary_generation(
                chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
            )

            if not is_pilot_run:
                # Summaries are written through their own short-lived connector.
                ms_connector = Neo4jConnector()
                try:
                    await add_memory_summary_nodes(summaries, ms_connector)
                    await add_memory_summary_statement_edges(summaries, ms_connector)
                finally:
                    await ms_connector.close()

            log_time("Memory Summary Generation", time.time() - step_start, log_file)
        except Exception as e:
            logger.error(f"Memory summary step failed: {e}", exc_info=True)

    except Exception as e:
        logger.error(f"Pipeline execution failed: {e}", exc_info=True)
        print(f"\n✗ 流水线执行失败: {e}")
        raise
    finally:
        # Best-effort cleanup: a close failure is logged, never raised.
        if neo4j_connector is not None:
            try:
                await neo4j_connector.close()
            except Exception as close_err:
                logger.warning(f"Failed to close Neo4j connector: {close_err}")

    # Reached only on success; failures propagate from the try block above.
    total_time = time.time() - pipeline_start
    log_time("TOTAL PIPELINE TIME", total_time, log_file)

    # Mark the end of this run in the timing log.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")

    logger.info("=== Pipeline Complete ===")
    logger.info(f"Total execution time: {total_time:.2f} seconds")
    logger.info(f"Timing details saved to: {log_file}")

    print("\n" + separator)
    print("✓ 流水线执行完成")
    print(f"✓ 总耗时: {total_time:.2f} 秒")
    print(f"✓ 详细日志: {log_file}")
    print(separator)


def _parse_dialogue_turns(dialogue_text: str) -> list[tuple[str, str]]:
    """Split raw dialogue text into ``(role, content)`` turns."""
    import re  # function-scope import, as in the original inline code

    # A turn is a speaker tag, the rest of that line, and every following line
    # that does not itself start with a speaker tag. The trailing group is
    # greedy on purpose: a lazy `*?` at the end of the pattern matches zero
    # repetitions and silently drops the continuation lines.
    pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*)"
    matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
    return [(role, content.strip()) for role, content in matches if content.strip()]


def _build_pilot_dialog(dialogue_text: str, group_id: str, user_id: str, apply_id: str):
    """Wrap dialogue text in a DialogData; return it with its message count."""
    messages = [
        ConversationMessage(role=role, msg=content)
        for role, content in _parse_dialogue_turns(dialogue_text)
    ]

    # No speaker-tagged turn found: treat the whole text as one user message.
    if not messages:
        messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]

    dialog = DialogData(
        context=ConversationContext(msgs=messages),
        ref_id="pilot_dialog_1",
        group_id=group_id,
        user_id=user_id,
        apply_id=apply_id,
        metadata={"source": "pilot_run", "input_type": "frontend_text"}
    )
    return dialog, len(messages)


async def _report_chunking_progress(progress_callback, chunked_dialogs, chunker_strategy: str) -> None:
    """Send one progress event per chunk, then a final chunking summary."""
    for dialog in chunked_dialogs:
        for chunk_number, chunk in enumerate(dialog.chunks, start=1):
            chunk_result = {
                "chunk_index": chunk_number,
                # Previews longer than 200 characters are truncated.
                "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
                "full_length": len(chunk.content),
                "dialog_id": dialog.id,
                "chunker_strategy": chunker_strategy
            }
            await progress_callback("text_preprocessing_result", f"分块 {chunk_number} 处理完成", chunk_result)

    preprocessing_summary = {
        "total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
        "total_dialogs": len(chunked_dialogs),
        "chunker_strategy": chunker_strategy
    }
    await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Running this module directly is no longer supported: main() needs
    # explicit arguments. Print a usage hint, then stop with an error.
    usage_lines = (
        "⚠️ Warning: This script now requires explicit configuration parameters.",
        "Global variables have been removed. Please provide configuration parameters.",
        "Example usage:",
        " asyncio.run(main(",
        " chunker_strategy='RecursiveChunker',",
        " group_id='your_group_id',",
        " user_id='your_user_id',",
        " apply_id='your_apply_id',",
        " llm_model_id='your_llm_id',",
        " embedding_model_id='your_embedding_id'",
        " ))",
    )
    for usage_line in usage_lines:
        print(usage_line)

    # Always raised: there are no default configuration values to fall back on.
    raise RuntimeError("Global variables removed. Please provide explicit configuration parameters.")
|
||||
@@ -20,14 +20,14 @@ Classes:
|
||||
MemorySummaryNode: Node representing a memory summary
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
import re
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.core.memory.utils.alias_utils import validate_aliases
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
def parse_historical_datetime(v):
|
||||
@@ -361,7 +361,7 @@ class ExtractedEntityNode(Node):
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
fact_summary: str = Field(..., description="Summary of the fact about this entity")
|
||||
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
|
||||
@@ -10,9 +10,10 @@ Classes:
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
"""Represents an extracted entity from dialogue.
|
||||
|
||||
@@ -5,7 +5,10 @@ import math
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
@@ -14,16 +17,14 @@ from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import (
|
||||
ForgettingEngine,
|
||||
)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_embedder_config,
|
||||
get_pipeline_config,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
|
||||
from app.core.memory.utils.data.text_utils import extract_plain_query
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_graph,
|
||||
search_graph_by_chunk_id,
|
||||
@@ -34,6 +35,7 @@ from app.repositories.neo4j.graph_search import (
|
||||
|
||||
# Use the new repository layer
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
@@ -324,229 +326,229 @@ def apply_reranker_placeholder(
|
||||
If config enables reranker, annotate items with a final_score equal to combined_score
|
||||
and keep ordering. This is a no-op reranker to be replaced later.
|
||||
"""
|
||||
try:
|
||||
rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load reranker config: {e}")
|
||||
rc = {}
|
||||
if not rc or not rc.get("enabled", False):
|
||||
return results
|
||||
# try:
|
||||
# rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
|
||||
# except Exception as e:
|
||||
# logger.debug(f"Failed to load reranker config: {e}")
|
||||
# rc = {}
|
||||
# if not rc or not rc.get("enabled", False):
|
||||
# return results
|
||||
|
||||
top_k = int(rc.get("top_k", 100))
|
||||
model_name = rc.get("model", "placeholder")
|
||||
# top_k = int(rc.get("top_k", 100))
|
||||
# model_name = rc.get("model", "placeholder")
|
||||
|
||||
for cat, items in results.items():
|
||||
head = items[:top_k]
|
||||
for it in head:
|
||||
base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
|
||||
it["final_score"] = base
|
||||
it["reranker_model"] = model_name
|
||||
# Keep overall order by final_score if present, otherwise combined/score
|
||||
results[cat] = sorted(
|
||||
items,
|
||||
key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
|
||||
reverse=True,
|
||||
)
|
||||
# for cat, items in results.items():
|
||||
# head = items[:top_k]
|
||||
# for it in head:
|
||||
# base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
|
||||
# it["final_score"] = base
|
||||
# it["reranker_model"] = model_name
|
||||
# # Keep overall order by final_score if present, otherwise combined/score
|
||||
# results[cat] = sorted(
|
||||
# items,
|
||||
# key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
|
||||
# reverse=True,
|
||||
# )
|
||||
return results
|
||||
|
||||
|
||||
async def apply_llm_reranker(
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
reranker_client: Optional[Any] = None,
|
||||
llm_weight: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Apply LLM-based reranking to search results.
|
||||
# async def apply_llm_reranker(
|
||||
# results: Dict[str, List[Dict[str, Any]]],
|
||||
# query_text: str,
|
||||
# reranker_client: Optional[Any] = None,
|
||||
# llm_weight: Optional[float] = None,
|
||||
# top_k: Optional[int] = None,
|
||||
# batch_size: Optional[int] = None,
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Apply LLM-based reranking to search results.
|
||||
|
||||
Args:
|
||||
results: Search results organized by category
|
||||
query_text: Original search query
|
||||
reranker_client: Optional pre-initialized reranker client
|
||||
llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
top_k: Maximum number of items to rerank per category
|
||||
batch_size: Number of items to process concurrently
|
||||
# Args:
|
||||
# results: Search results organized by category
|
||||
# query_text: Original search query
|
||||
# reranker_client: Optional pre-initialized reranker client
|
||||
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
# top_k: Maximum number of items to rerank per category
|
||||
# batch_size: Number of items to process concurrently
|
||||
|
||||
Returns:
|
||||
Reranked results with final_score and reranker_model fields
|
||||
"""
|
||||
# Load reranker configuration from runtime.json
|
||||
try:
|
||||
rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load reranker config: {e}")
|
||||
rc = {}
|
||||
# Returns:
|
||||
# Reranked results with final_score and reranker_model fields
|
||||
# """
|
||||
# # Load reranker configuration from runtime.json
|
||||
# # try:
|
||||
# # rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
|
||||
# # except Exception as e:
|
||||
# # logger.debug(f"Failed to load reranker config: {e}")
|
||||
# # rc = {}
|
||||
|
||||
# Check if reranking is enabled
|
||||
enabled = rc.get("enabled", False)
|
||||
if not enabled:
|
||||
logger.debug("LLM reranking is disabled in configuration")
|
||||
return results
|
||||
# # Check if reranking is enabled
|
||||
# enabled = rc.get("enabled", False)
|
||||
# if not enabled:
|
||||
# logger.debug("LLM reranking is disabled in configuration")
|
||||
# return results
|
||||
|
||||
# Load configuration parameters with defaults
|
||||
llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
# # Load configuration parameters with defaults
|
||||
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
# top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
|
||||
# Initialize reranker client if not provided
|
||||
if reranker_client is None:
|
||||
try:
|
||||
reranker_client = get_reranker_client()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
return results
|
||||
# # Initialize reranker client if not provided
|
||||
# if reranker_client is None:
|
||||
# try:
|
||||
# reranker_client = get_reranker_client()
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
# return results
|
||||
|
||||
# Get model name for metadata
|
||||
model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
# # Get model name for metadata
|
||||
# model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
|
||||
# Process each category
|
||||
reranked_results = {}
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
items = results.get(category, [])
|
||||
if not items:
|
||||
reranked_results[category] = []
|
||||
continue
|
||||
# # Process each category
|
||||
# reranked_results = {}
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
# items = results.get(category, [])
|
||||
# if not items:
|
||||
# reranked_results[category] = []
|
||||
# continue
|
||||
|
||||
# Select top K items by combined_score for reranking
|
||||
sorted_items = sorted(
|
||||
items,
|
||||
key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
reverse=True
|
||||
)
|
||||
# # Select top K items by combined_score for reranking
|
||||
# sorted_items = sorted(
|
||||
# items,
|
||||
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
# reverse=True
|
||||
# )
|
||||
|
||||
top_items = sorted_items[:top_k]
|
||||
remaining_items = sorted_items[top_k:]
|
||||
# top_items = sorted_items[:top_k]
|
||||
# remaining_items = sorted_items[top_k:]
|
||||
|
||||
# Extract text content from each item
|
||||
def extract_text(item: Dict[str, Any]) -> str:
|
||||
"""Extract text content from a result item."""
|
||||
# Try different text fields based on category
|
||||
text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
return str(text).strip()
|
||||
# # Extract text content from each item
|
||||
# def extract_text(item: Dict[str, Any]) -> str:
|
||||
# """Extract text content from a result item."""
|
||||
# # Try different text fields based on category
|
||||
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
# return str(text).strip()
|
||||
|
||||
# Batch items for concurrent processing
|
||||
batches = []
|
||||
for i in range(0, len(top_items), batch_size):
|
||||
batch = top_items[i:i + batch_size]
|
||||
batches.append(batch)
|
||||
# # Batch items for concurrent processing
|
||||
# batches = []
|
||||
# for i in range(0, len(top_items), batch_size):
|
||||
# batch = top_items[i:i + batch_size]
|
||||
# batches.append(batch)
|
||||
|
||||
# Process batches concurrently
|
||||
async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Process a batch of items with LLM relevance scoring."""
|
||||
scored_batch = []
|
||||
# # Process batches concurrently
|
||||
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
# """Process a batch of items with LLM relevance scoring."""
|
||||
# scored_batch = []
|
||||
|
||||
for item in batch:
|
||||
item_text = extract_text(item)
|
||||
# for item in batch:
|
||||
# item_text = extract_text(item)
|
||||
|
||||
# Skip items with no text
|
||||
if not item_text:
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["llm_relevance_score"] = 0.0
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
continue
|
||||
# # Skip items with no text
|
||||
# if not item_text:
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["llm_relevance_score"] = 0.0
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
# continue
|
||||
|
||||
# Create relevance scoring prompt
|
||||
prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
# # Create relevance scoring prompt
|
||||
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
|
||||
Query: {query_text}
|
||||
# Query: {query_text}
|
||||
|
||||
Result: {item_text}
|
||||
# Result: {item_text}
|
||||
|
||||
Respond with only a number between 0.0 and 1.0, where:
|
||||
- 0.0 means completely irrelevant
|
||||
- 1.0 means perfectly relevant
|
||||
# Respond with only a number between 0.0 and 1.0, where:
|
||||
# - 0.0 means completely irrelevant
|
||||
# - 1.0 means perfectly relevant
|
||||
|
||||
Relevance score:"""
|
||||
# Relevance score:"""
|
||||
|
||||
# Send request to LLM
|
||||
try:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
response = await reranker_client.chat(messages)
|
||||
# # Send request to LLM
|
||||
# try:
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
# response = await reranker_client.chat(messages)
|
||||
|
||||
# Parse LLM response to extract relevance score
|
||||
response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
# # Parse LLM response to extract relevance score
|
||||
# response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
|
||||
# Try to extract a float from the response
|
||||
try:
|
||||
# Remove any non-numeric characters except decimal point
|
||||
import re
|
||||
score_match = re.search(r'(\d+\.?\d*)', response_text)
|
||||
if score_match:
|
||||
llm_score = float(score_match.group(1))
|
||||
# Clamp to [0.0, 1.0]
|
||||
llm_score = max(0.0, min(1.0, llm_score))
|
||||
else:
|
||||
raise ValueError("No numeric score found in response")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
llm_score = None
|
||||
# # Try to extract a float from the response
|
||||
# try:
|
||||
# # Remove any non-numeric characters except decimal point
|
||||
# import re
|
||||
# score_match = re.search(r'(\d+\.?\d*)', response_text)
|
||||
# if score_match:
|
||||
# llm_score = float(score_match.group(1))
|
||||
# # Clamp to [0.0, 1.0]
|
||||
# llm_score = max(0.0, min(1.0, llm_score))
|
||||
# else:
|
||||
# raise ValueError("No numeric score found in response")
|
||||
# except (ValueError, AttributeError) as e:
|
||||
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
# llm_score = None
|
||||
|
||||
# Calculate final score
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# # Calculate final score
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
|
||||
if llm_score is not None:
|
||||
final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
item_copy["llm_relevance_score"] = llm_score
|
||||
else:
|
||||
# Use combined_score as fallback
|
||||
final_score = combined_score
|
||||
item_copy["llm_relevance_score"] = combined_score
|
||||
# if llm_score is not None:
|
||||
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
# item_copy["llm_relevance_score"] = llm_score
|
||||
# else:
|
||||
# # Use combined_score as fallback
|
||||
# final_score = combined_score
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
|
||||
item_copy["final_score"] = final_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["llm_relevance_score"] = combined_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_batch.append(item_copy)
|
||||
# item_copy["final_score"] = final_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
|
||||
return scored_batch
|
||||
# return scored_batch
|
||||
|
||||
# Process all batches concurrently
|
||||
try:
|
||||
batch_tasks = [process_batch(batch) for batch in batches]
|
||||
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
# # Process all batches concurrently
|
||||
# try:
|
||||
# batch_tasks = [process_batch(batch) for batch in batches]
|
||||
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
# Merge batch results
|
||||
scored_items = []
|
||||
for result in batch_results:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"Batch processing failed: {result}")
|
||||
continue
|
||||
scored_items.extend(result)
|
||||
# # Merge batch results
|
||||
# scored_items = []
|
||||
# for result in batch_results:
|
||||
# if isinstance(result, Exception):
|
||||
# logger.warning(f"Batch processing failed: {result}")
|
||||
# continue
|
||||
# scored_items.extend(result)
|
||||
|
||||
# Add remaining items (not in top K) with their combined_score as final_score
|
||||
for item in remaining_items:
|
||||
item_copy = item.copy()
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item_copy["final_score"] = combined_score
|
||||
item_copy["reranker_model"] = model_name
|
||||
scored_items.append(item_copy)
|
||||
# # Add remaining items (not in top K) with their combined_score as final_score
|
||||
# for item in remaining_items:
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_items.append(item_copy)
|
||||
|
||||
# Sort all items by final_score in descending order
|
||||
scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
reranked_results[category] = scored_items
|
||||
# # Sort all items by final_score in descending order
|
||||
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
# reranked_results[category] = scored_items
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# Return original items with combined_score as final_score
|
||||
for item in items:
|
||||
combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
item["final_score"] = combined_score
|
||||
item["reranker_model"] = model_name
|
||||
reranked_results[category] = items
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# # Return original items with combined_score as final_score
|
||||
# for item in items:
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
# item["final_score"] = combined_score
|
||||
# item["reranker_model"] = model_name
|
||||
# reranked_results[category] = items
|
||||
|
||||
return reranked_results
|
||||
# return reranked_results
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
@@ -556,7 +558,7 @@ async def run_hybrid_search(
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
embedding_id: str,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
@@ -564,11 +566,14 @@ async def run_hybrid_search(
|
||||
"""
|
||||
|
||||
Run search with specified type: 'keyword', 'embedding', or 'hybrid'
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing embedding_model_id and config_id
|
||||
"""
|
||||
# Start overall timing
|
||||
search_start_time = time.time()
|
||||
latency_metrics = {}
|
||||
logger.info(f"using embedding_id:{embedding_id}...")
|
||||
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
|
||||
|
||||
# Clean and normalize the incoming query before use/logging
|
||||
query_text = extract_plain_query(query_text)
|
||||
@@ -621,7 +626,9 @@ async def run_hybrid_search(
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
embedder_config_dict = get_embedder_config(embedding_id)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
@@ -683,7 +690,7 @@ async def run_hybrid_search(
|
||||
if use_forgetting_rerank:
|
||||
# Load forgetting parameters from pipeline config
|
||||
try:
|
||||
pc = get_pipeline_config()
|
||||
pc = get_pipeline_config(memory_config)
|
||||
forgetting_cfg = pc.forgetting_engine
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
|
||||
@@ -711,16 +718,16 @@ async def run_hybrid_search(
|
||||
|
||||
# Apply LLM reranking if enabled
|
||||
llm_rerank_applied = False
|
||||
if use_llm_rerank:
|
||||
try:
|
||||
reranked_results = await apply_llm_reranker(
|
||||
results=reranked_results,
|
||||
query_text=query_text,
|
||||
)
|
||||
llm_rerank_applied = True
|
||||
logger.info("LLM reranking applied successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
# if use_llm_rerank:
|
||||
# try:
|
||||
# reranked_results = await apply_llm_reranker(
|
||||
# results=reranked_results,
|
||||
# query_text=query_text,
|
||||
# )
|
||||
# llm_rerank_applied = True
|
||||
# logger.info("LLM reranking applied successfully")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
|
||||
results["reranked_results"] = reranked_results
|
||||
results["combined_summary"] = {
|
||||
@@ -896,90 +903,95 @@ async def search_chunk_by_chunk_id(
|
||||
return {"chunks": chunks}
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the hybrid graph search CLI.
|
||||
# def main():
|
||||
# """Main entry point for the hybrid graph search CLI.
|
||||
|
||||
Parses command line arguments and executes search with specified parameters.
|
||||
Supports keyword, embedding, and hybrid search modes.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
|
||||
parser.add_argument(
|
||||
"--query", "-q", required=True, help="Free-text query to search"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search-type",
|
||||
"-t",
|
||||
choices=["keyword", "embedding", "hybrid"],
|
||||
default="hybrid",
|
||||
help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-name",
|
||||
"-m",
|
||||
default="openai/nomic-embed-text:v1.5",
|
||||
help="Embedding config name from config.json (default: openai/nomic-embed-text:v1.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-id",
|
||||
"-g",
|
||||
default=None,
|
||||
help="Optional group_id to filter results (default: None)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
"-k",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Max number of results per type (default: 5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include",
|
||||
"-i",
|
||||
nargs="+",
|
||||
default=["statements", "chunks", "entities", "summaries"],
|
||||
choices=["statements", "chunks", "entities", "summaries"],
|
||||
help="Which targets to search for embedding search (default: statements chunks entities summaries)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
default="search_results.json",
|
||||
help="Path to save the search results JSON (default: search_results.json)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rerank-alpha",
|
||||
"-a",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--forgetting-rerank",
|
||||
action="store_true",
|
||||
help="Apply forgetting curve during reranking for hybrid search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-rerank",
|
||||
action="store_true",
|
||||
help="Apply LLM-based reranking for hybrid search.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
# Parses command line arguments and executes search with specified parameters.
|
||||
# Supports keyword, embedding, and hybrid search modes.
|
||||
# """
|
||||
# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
|
||||
# parser.add_argument(
|
||||
# "--query", "-q", required=True, help="Free-text query to search"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--search-type",
|
||||
# "-t",
|
||||
# choices=["keyword", "embedding", "hybrid"],
|
||||
# default="hybrid",
|
||||
# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--config-id",
|
||||
# "-c",
|
||||
# type=int,
|
||||
# required=True,
|
||||
# help="Database configuration ID (required)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--group-id",
|
||||
# "-g",
|
||||
# default=None,
|
||||
# help="Optional group_id to filter results (default: None)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--limit",
|
||||
# "-k",
|
||||
# type=int,
|
||||
# default=5,
|
||||
# help="Max number of results per type (default: 5)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--include",
|
||||
# "-i",
|
||||
# nargs="+",
|
||||
# default=["statements", "chunks", "entities", "summaries"],
|
||||
# choices=["statements", "chunks", "entities", "summaries"],
|
||||
# help="Which targets to search for embedding search (default: statements chunks entities summaries)"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--output",
|
||||
# "-o",
|
||||
# default="search_results.json",
|
||||
# help="Path to save the search results JSON (default: search_results.json)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--rerank-alpha",
|
||||
# "-a",
|
||||
# type=float,
|
||||
# default=0.6,
|
||||
# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--forgetting-rerank",
|
||||
# action="store_true",
|
||||
# help="Apply forgetting curve during reranking for hybrid search.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--llm-rerank",
|
||||
# action="store_true",
|
||||
# help="Apply LLM-based reranking for hybrid search.",
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
|
||||
asyncio.run(
|
||||
run_hybrid_search(
|
||||
query_text=args.query,
|
||||
search_type=args.search_type,
|
||||
group_id=args.group_id,
|
||||
limit=args.limit,
|
||||
include=args.include,
|
||||
output_path=args.output,
|
||||
embedding_id=config_defs.SELECTED_EMBEDDING_ID,
|
||||
rerank_alpha=args.rerank_alpha,
|
||||
use_forgetting_rerank=args.forgetting_rerank,
|
||||
use_llm_rerank=args.llm_rerank,
|
||||
)
|
||||
)
|
||||
# # Load memory config from database
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
# memory_config = MemoryConfigService.load_memory_config(args.config_id)
|
||||
|
||||
# asyncio.run(
|
||||
# run_hybrid_search(
|
||||
# query_text=args.query,
|
||||
# search_type=args.search_type,
|
||||
# group_id=args.group_id,
|
||||
# limit=args.limit,
|
||||
# include=args.include,
|
||||
# output_path=args.output,
|
||||
# memory_config=memory_config,
|
||||
# rerank_alpha=args.rerank_alpha,
|
||||
# use_forgetting_rerank=args.forgetting_rerank,
|
||||
# use_llm_rerank=args.llm_rerank,
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
"""
|
||||
Deduplication helper functions.
|
||||
"""
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from typing import List, Dict, Tuple, Any
|
||||
from app.core.memory.models.graph_models import(
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode
|
||||
)
|
||||
import os
|
||||
from datetime import datetime
|
||||
import difflib # 提供字符串相似度计算工具
|
||||
import asyncio
|
||||
import difflib # 提供字符串相似度计算工具
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
)
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
|
||||
|
||||
# Module-level helper for unifying entity types
|
||||
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
|
||||
"""统一实体类型:基于LLM建议或启发式规则选择最合适的类型。
|
||||
@@ -705,7 +708,8 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
id_redirect: Dict[str, str],
|
||||
config: DedupConfig | None = None,
|
||||
config: DedupConfig,
|
||||
llm_client = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
|
||||
"""
|
||||
基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。
|
||||
@@ -717,26 +721,13 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
"""
|
||||
llm_records: List[str] = []
|
||||
try:
|
||||
# 优先使用运行时配置;若未提供配置,使用模型默认值,不再回退到环境变量
|
||||
enable_switch = (
|
||||
bool(config.enable_llm_dedup_blockwise) if config is not None else DedupConfig().enable_llm_dedup_blockwise
|
||||
)
|
||||
if not enable_switch:
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
# 从配置读取 LLM 迭代参数;若无配置则使用 DedupConfig 的默认值
|
||||
_defaults = DedupConfig()
|
||||
block_size = (config.llm_block_size if config is not None else _defaults.llm_block_size)
|
||||
block_concurrency = (config.llm_block_concurrency if config is not None else _defaults.llm_block_concurrency)
|
||||
pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency)
|
||||
max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds)
|
||||
|
||||
# 动态导入 llm 客户端(修正导入路径)
|
||||
try:
|
||||
llm_utils_mod = importlib.import_module("app.core.memory.utils.llm.llm_utils")
|
||||
get_llm_client_fn = llm_utils_mod.get_llm_client
|
||||
except Exception as e:
|
||||
llm_records.append(f"[LLM错误] 无法导入 llm_utils 模块: {e}")
|
||||
if not bool(config.enable_llm_dedup_blockwise):
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
# 从配置读取 LLM 迭代参数
|
||||
block_size = config.llm_block_size
|
||||
block_concurrency = config.llm_block_concurrency
|
||||
pair_concurrency = config.llm_pair_concurrency
|
||||
max_rounds = config.llm_max_rounds
|
||||
|
||||
try:
|
||||
llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm")
|
||||
@@ -745,14 +736,9 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
llm_records.append(f"[LLM错误] 无法导入 entity_dedup_llm 模块: {e}")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
# 获取 LLM 客户端
|
||||
try:
|
||||
llm_client = get_llm_client_fn()
|
||||
if llm_client is None:
|
||||
llm_records.append("[LLM错误] LLM 客户端初始化失败:返回 None")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
except Exception as e:
|
||||
llm_records.append(f"[LLM错误] 获取 LLM 客户端失败: {e}")
|
||||
# 验证 LLM 客户端
|
||||
if llm_client is None:
|
||||
llm_records.append("[LLM错误] LLM 客户端未提供")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
llm_redirect, llm_records = await llm_fn(
|
||||
@@ -813,7 +799,8 @@ async def LLM_disamb_decision(
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
id_redirect: Dict[str, str],
|
||||
config: DedupConfig | None = None,
|
||||
config: DedupConfig,
|
||||
llm_client = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]:
|
||||
"""
|
||||
预消歧阶段:对“同名但类型不同”的实体对调用LLM进行消歧,
|
||||
@@ -824,22 +811,16 @@ async def LLM_disamb_decision(
|
||||
disamb_records: List[str] = []
|
||||
blocked_pairs: set[tuple[str, str]] = set()
|
||||
try:
|
||||
enable_switch = (
|
||||
config.enable_llm_disambiguation
|
||||
if config is not None
|
||||
else DedupConfig().enable_llm_disambiguation
|
||||
)
|
||||
if not bool(enable_switch):
|
||||
if not bool(config.enable_llm_disambiguation):
|
||||
return deduped_entities, id_redirect, blocked_pairs, disamb_records
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import (
|
||||
llm_disambiguate_pairs_iterative,
|
||||
)
|
||||
|
||||
# 获取 LLM 客户端并验证
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
# 验证 LLM 客户端
|
||||
if llm_client is None:
|
||||
disamb_records.append("[DISAMB错误] LLM 客户端初始化失败:返回 None")
|
||||
disamb_records.append("[DISAMB错误] LLM 客户端未提供")
|
||||
return deduped_entities, id_redirect, blocked_pairs, disamb_records
|
||||
|
||||
merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative(
|
||||
@@ -895,6 +876,7 @@ async def deduplicate_entities_and_edges(
|
||||
report_append: bool = False,
|
||||
report_stage_notes: List[str] | None = None,
|
||||
dedup_config: DedupConfig | None = None,
|
||||
llm_client = None,
|
||||
) -> Tuple[
|
||||
List[ExtractedEntityNode],
|
||||
List[StatementEntityEdge],
|
||||
@@ -911,7 +893,7 @@ async def deduplicate_entities_and_edges(
|
||||
|
||||
# 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并
|
||||
deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision(
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
|
||||
)
|
||||
|
||||
# 2) 模糊匹配(本地规则)
|
||||
@@ -936,7 +918,7 @@ async def deduplicate_entities_and_edges(
|
||||
|
||||
if should_trigger_llm:
|
||||
deduped_entities, id_redirect, llm_decision_records = await LLM_decision(
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
|
||||
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
|
||||
)
|
||||
else:
|
||||
llm_decision_records = []
|
||||
|
||||
@@ -10,15 +10,27 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
)
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( # 导入报告写入以在跳过时追加说明
|
||||
_write_dedup_fusion_report,
|
||||
deduplicate_entities_and_edges,
|
||||
)
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
get_dedup_candidates_for_entities, # 导入ge函数,用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互
|
||||
from app.repositories.neo4j.graph_search import get_dedup_candidates_for_entities # 导入ge函数,用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
|
||||
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges, _write_dedup_fusion_report # 导入报告写入以在跳过时追加说明
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from app.repositories.neo4j.neo4j_connector import (
|
||||
Neo4jConnector, # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互
|
||||
)
|
||||
|
||||
|
||||
def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt,用于将任意类型的输入值解析为datetime对象(处理实体节点中的时间字段)
|
||||
@@ -72,6 +84,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
||||
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
||||
dedup_config: DedupConfig | None = None,
|
||||
llm_client = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||
"""
|
||||
第二层去重消歧:
|
||||
@@ -137,13 +150,14 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes)
|
||||
|
||||
# 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重)
|
||||
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges = await deduplicate_entities_and_edges(
|
||||
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges, _ = await deduplicate_entities_and_edges(
|
||||
union_entities,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
report_stage="第二层去重消歧",
|
||||
report_append=True,
|
||||
dedup_config=dedup_config,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import second_layer_dedup_and_merge_with_neo4j
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.models.graph_models import (
|
||||
DialogueNode,
|
||||
ChunkNode,
|
||||
StatementNode,
|
||||
DialogueNode,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
deduplicate_entities_and_edges,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
|
||||
second_layer_dedup_and_merge_with_neo4j,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def dedup_layers_and_merge_and_return(
|
||||
@@ -29,8 +33,9 @@ async def dedup_layers_and_merge_and_return(
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: Optional[ExtractionPipelineConfig] = None,
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client = None,
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
@@ -48,12 +53,9 @@ async def dedup_layers_and_merge_and_return(
|
||||
返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。
|
||||
"""
|
||||
|
||||
# 默认从 runtime.json 加载管线配置,避免回退到环境变量
|
||||
# pipeline_config is required - caller must provide it
|
||||
if pipeline_config is None:
|
||||
try:
|
||||
pipeline_config = get_pipeline_config()
|
||||
except Exception:
|
||||
pipeline_config = None
|
||||
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
|
||||
|
||||
# 先探测 group_id,决定报告写入策略
|
||||
group_id: Optional[str] = None
|
||||
@@ -70,6 +72,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
report_stage="第一层去重消歧",
|
||||
report_append=False,
|
||||
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
# 初始化第二层融合结果为第一层结果
|
||||
@@ -88,6 +91,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
statement_entity_edges=dedup_statement_entity_edges,
|
||||
entity_entity_edges=dedup_entity_entity_edges,
|
||||
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
|
||||
llm_client=llm_client,
|
||||
)
|
||||
else:
|
||||
print("Skip second-layer dedup: missing connector")
|
||||
|
||||
@@ -1253,6 +1253,7 @@ class ExtractionOrchestrator:
|
||||
report_stage="第一层去重消歧(试运行)",
|
||||
report_append=False,
|
||||
dedup_config=self.config.deduplication,
|
||||
llm_client=self.llm_client,
|
||||
)
|
||||
|
||||
# 保存去重消歧的详细记录到实例变量
|
||||
@@ -1284,6 +1285,7 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
self.config,
|
||||
self.connector,
|
||||
llm_client=self.llm_client,
|
||||
)
|
||||
|
||||
# 解包返回值
|
||||
|
||||
@@ -5,11 +5,13 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
class EmbeddingGenerator:
|
||||
@@ -21,7 +23,9 @@ class EmbeddingGenerator:
|
||||
Args:
|
||||
embedding_id: 嵌入模型 ID
|
||||
"""
|
||||
embedder_config = get_embedder_config(embedding_id)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config = config_service.get_embedder_config(embedding_id)
|
||||
self.embedder_client = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(embedder_config),
|
||||
)
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.base_response import RobustLLMResponse
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
from pydantic import Field
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
from app.core.memory.models.base_response import RobustLLMResponse
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class MemorySummaryResponse(RobustLLMResponse):
|
||||
@@ -91,22 +87,17 @@ async def _process_chunk_summary(
|
||||
return None
|
||||
|
||||
|
||||
async def Memory_summary_generation(
|
||||
async def memory_summary_generation(
|
||||
chunked_dialogs: List[DialogData],
|
||||
llm_client,
|
||||
embedding_id,
|
||||
embedder_client: OpenAIEmbedderClient,
|
||||
) -> List[MemorySummaryNode]:
|
||||
"""Generate memory summaries per chunk, embed them, and return nodes."""
|
||||
embedder_cfg_dict = get_embedder_config(embedding_id)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(embedder_cfg_dict),
|
||||
)
|
||||
|
||||
# Collect all tasks for parallel processing
|
||||
tasks = []
|
||||
for dialog in chunked_dialogs:
|
||||
for chunk in dialog.chunks:
|
||||
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder))
|
||||
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder_client))
|
||||
|
||||
# Process all chunks in parallel
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo
|
||||
|
||||
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
from app.core.memory.models.variate_config import StatementExtractionConfig
|
||||
from app.core.memory.utils.data.ontology import (
|
||||
LABEL_DEFINITIONS,
|
||||
RelevenceInfo,
|
||||
StatementType,
|
||||
TemporalInfo,
|
||||
)
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo, RelevenceInfo
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -8,21 +8,25 @@
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
neo4j_query_all,
|
||||
neo4j_query_part,
|
||||
neo4j_statement_all,
|
||||
neo4j_statement_part,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
_root_logger = logging.getLogger()
|
||||
@@ -135,14 +139,20 @@ class ReflectionEngine:
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
|
||||
if self.llm_client is None:
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
model_id = self.llm_client
|
||||
self.llm_client = get_llm_client(model_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(model_id)
|
||||
|
||||
if self.get_data_func is None:
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
@@ -154,11 +164,15 @@ class ReflectionEngine:
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
if self.render_evaluate_prompt_func is None:
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
from app.core.memory.utils.prompt.template_render import (
|
||||
render_evaluate_prompt,
|
||||
)
|
||||
self.render_evaluate_prompt_func = render_evaluate_prompt
|
||||
|
||||
if self.render_reflexion_prompt_func is None:
|
||||
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
||||
from app.core.memory.utils.prompt.template_render import (
|
||||
render_reflexion_prompt,
|
||||
)
|
||||
self.render_reflexion_prompt_func = render_reflexion_prompt
|
||||
|
||||
if self.conflict_schema is None:
|
||||
@@ -170,7 +184,9 @@ class ReflectionEngine:
|
||||
self.reflexion_schema = ReflexionResultSchema
|
||||
|
||||
if self.update_query is None:
|
||||
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
UPDATE_STATEMENT_INVALID_AT,
|
||||
)
|
||||
self.update_query = UPDATE_STATEMENT_INVALID_AT
|
||||
|
||||
self._lazy_init_done = True
|
||||
|
||||
@@ -4,10 +4,20 @@
|
||||
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
|
||||
"""
|
||||
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.search_strategy import (
|
||||
SearchResult,
|
||||
SearchStrategy,
|
||||
)
|
||||
from app.core.memory.storage_services.search.semantic_search import (
|
||||
SemanticSearchStrategy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SearchStrategy",
|
||||
@@ -34,7 +44,7 @@ async def run_hybrid_search(
|
||||
include: list[str] | None = None,
|
||||
alpha: float = 0.6,
|
||||
use_forgetting_curve: bool = False,
|
||||
embedding_id: str | None = None,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""运行混合搜索(向后兼容的函数式API)
|
||||
@@ -51,24 +61,26 @@ async def run_hybrid_search(
|
||||
include: 要包含的搜索类别列表
|
||||
alpha: BM25分数权重(0.0-1.0)
|
||||
use_forgetting_curve: 是否使用遗忘曲线
|
||||
embedding_id: 嵌入模型ID
|
||||
memory_config: MemoryConfig object containing embedding_model_id
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
dict: 搜索结果字典,格式与旧API兼容
|
||||
"""
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# 使用提供的embedding_id或默认值
|
||||
emb_id = embedding_id or config_defs.SELECTED_EMBEDDING_ID
|
||||
if not memory_config:
|
||||
raise ValueError("memory_config is required for search")
|
||||
|
||||
# 初始化客户端
|
||||
connector = Neo4jConnector()
|
||||
embedder_config_dict = get_embedder_config(emb_id)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
|
||||
@@ -5,15 +5,20 @@
|
||||
使用余弦相似度进行语义匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.storage_services.search.search_strategy import (
|
||||
SearchResult,
|
||||
SearchStrategy,
|
||||
)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
@@ -62,7 +67,9 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
"""
|
||||
try:
|
||||
# 从数据库读取嵌入器配置
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
|
||||
@@ -9,7 +9,6 @@ from .config_utils import (
|
||||
get_chunker_config,
|
||||
get_embedder_config,
|
||||
get_model_config,
|
||||
get_neo4j_config,
|
||||
get_picture_config,
|
||||
get_pipeline_config,
|
||||
get_pruning_config,
|
||||
@@ -41,7 +40,6 @@ __all__ = [
|
||||
# config_utils
|
||||
"get_model_config",
|
||||
"get_embedder_config",
|
||||
"get_neo4j_config",
|
||||
"get_chunker_config",
|
||||
"get_pipeline_config",
|
||||
"get_pruning_config",
|
||||
|
||||
@@ -1,90 +1,74 @@
|
||||
from app.core.memory.models.variate_config import (
|
||||
DedupConfig,
|
||||
ExtractionPipelineConfig,
|
||||
ForgettingEngineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import CONFIG
|
||||
from app.db import get_db
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
"""
|
||||
Configuration utilities - Backward compatibility layer
|
||||
|
||||
DEPRECATED: These functions now require a db session parameter.
|
||||
New code should use MemoryConfigService(db) instance directly.
|
||||
|
||||
For functions that don't require db (get_pipeline_config, get_pruning_config),
|
||||
they are still re-exported here.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# These functions don't require db - safe to re-export as static methods
|
||||
get_pipeline_config = MemoryConfigService.get_pipeline_config
|
||||
get_pruning_config = MemoryConfigService.get_pruning_config
|
||||
|
||||
|
||||
def get_model_config(model_id: str, db: Session | None = None) -> dict:
|
||||
def get_model_config(model_id: str, db=None):
|
||||
"""DEPRECATED: Use MemoryConfigService(db).get_model_config(model_id) directly."""
|
||||
if db is None:
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen) # 取到真正的 Session
|
||||
raise ValueError(
|
||||
"get_model_config now requires a db session. "
|
||||
"Use MemoryConfigService(db).get_model_config(model_id) directly."
|
||||
)
|
||||
return MemoryConfigService(db).get_model_config(model_id)
|
||||
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
if not config:
|
||||
print(f"模型ID {model_id} 不存在")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
|
||||
# 从环境变量读取超时和重试配置
|
||||
from app.core.config import settings
|
||||
|
||||
model_config = {
|
||||
"model_name": apiConfig.model_name,
|
||||
"provider": apiConfig.provider,
|
||||
"api_key": apiConfig.api_key,
|
||||
"base_url": apiConfig.api_base,
|
||||
"model_config_id":apiConfig.model_config_id,
|
||||
"type": config.type,
|
||||
# 添加超时和重试配置,避免 LLM 请求超时
|
||||
"timeout": settings.LLM_TIMEOUT, # 从环境变量读取,默认120秒
|
||||
"max_retries": settings.LLM_MAX_RETRIES, # 从环境变量读取,默认2次
|
||||
}
|
||||
# 写入model_config.log文件中
|
||||
with open("logs/model_config.log", "a", encoding="utf-8") as f:
|
||||
f.write(f"模型ID: {model_id}\n")
|
||||
f.write(f"模型配置信息:\n{model_config}\n")
|
||||
f.write("=============================\n\n")
|
||||
return model_config
|
||||
|
||||
def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict:
|
||||
def get_embedder_config(embedding_id: str, db=None):
|
||||
"""DEPRECATED: Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."""
|
||||
if db is None:
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen) # 取到真正的 Session
|
||||
raise ValueError(
|
||||
"get_embedder_config now requires a db session. "
|
||||
"Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."
|
||||
)
|
||||
return MemoryConfigService(db).get_embedder_config(embedding_id)
|
||||
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=embedding_id)
|
||||
if not config:
|
||||
print(f"嵌入模型ID {embedding_id} 不存在")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
model_config = {
|
||||
"model_name": apiConfig.model_name,
|
||||
"provider": apiConfig.provider,
|
||||
"api_key": apiConfig.api_key,
|
||||
"base_url": apiConfig.api_base,
|
||||
"model_config_id":apiConfig.model_config_id,
|
||||
# Ensure required field for RedBearModelConfig validation
|
||||
"type": config.type,
|
||||
# 添加超时和重试配置,避免嵌入服务请求超时
|
||||
"timeout": 120.0, # 嵌入服务超时时间(秒)
|
||||
"max_retries": 5, # 最大重试次数
|
||||
}
|
||||
# 写入embedder_config.log文件中
|
||||
with open("logs/embedder_config.log", "a", encoding="utf-8") as f:
|
||||
f.write(f"嵌入模型ID: {embedding_id}\n")
|
||||
f.write(f"嵌入模型配置信息:\n{model_config}\n")
|
||||
f.write("=============================\n\n")
|
||||
return model_config
|
||||
|
||||
def get_neo4j_config() -> dict:
|
||||
"""Retrieves the Neo4j configuration from the config file."""
|
||||
return CONFIG.get("neo4j", {})
|
||||
def get_picture_config(llm_name: str) -> dict:
|
||||
"""Retrieves the configuration for a specific model from the config file."""
|
||||
"""Retrieves the configuration for a specific model from the config file.
|
||||
|
||||
.. deprecated::
|
||||
This function is deprecated and will be removed in a future version.
|
||||
Use database-backed model configuration instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"get_picture_config is deprecated and will be removed in a future version. "
|
||||
"Use database-backed model configuration instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
for model_config in CONFIG.get("picture_recognition", []):
|
||||
if model_config["llm_name"] == llm_name:
|
||||
return model_config
|
||||
raise ValueError(f"Model '{llm_name}' not found in config.json")
|
||||
|
||||
|
||||
def get_voice_config(llm_name: str) -> dict:
|
||||
"""Retrieves the configuration for a specific model from the config file."""
|
||||
"""Retrieves the configuration for a specific model from the config file.
|
||||
|
||||
.. deprecated::
|
||||
This function is deprecated and will be removed in a future version.
|
||||
Use database-backed model configuration instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"get_voice_config is deprecated and will be removed in a future version. "
|
||||
"Use database-backed model configuration instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
for model_config in CONFIG.get("voice_recognition", []):
|
||||
if model_config["llm_name"] == llm_name:
|
||||
return model_config
|
||||
@@ -92,19 +76,8 @@ def get_voice_config(llm_name: str) -> dict:
|
||||
|
||||
|
||||
def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
"""Retrieves the configuration for a specific chunker strategy.
|
||||
"""Retrieves the configuration for a specific chunker strategy."""
|
||||
|
||||
Enhancements:
|
||||
- Supports default configs for `LLMChunker` and `HybridChunker` if not present.
|
||||
- Falls back to the first available chunker config when the requested one is missing.
|
||||
"""
|
||||
# 1) Try to find exact match in config
|
||||
chunker_list = CONFIG.get("chunker_list", [])
|
||||
for chunker_config in chunker_list:
|
||||
if chunker_config.get("chunker_strategy") == chunker_strategy:
|
||||
return chunker_config
|
||||
|
||||
# 2) Provide sane defaults for newer strategies
|
||||
default_configs = {
|
||||
"RecursiveChunker": {
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
@@ -112,7 +85,6 @@ def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
"chunk_size": 512,
|
||||
"min_characters_per_chunk": 50
|
||||
},
|
||||
|
||||
"LLMChunker": {
|
||||
"chunker_strategy": "LLMChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
@@ -137,127 +109,6 @@ def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
if chunker_strategy in default_configs:
|
||||
return default_configs[chunker_strategy]
|
||||
|
||||
# 3) Fallback: use first available config but tag with requested strategy
|
||||
if chunker_list:
|
||||
fallback = chunker_list[0].copy()
|
||||
fallback["chunker_strategy"] = chunker_strategy
|
||||
# Non-fatal notice for visibility in logs if any
|
||||
print(f"Warning: Using first available chunker config as fallback for '{chunker_strategy}'")
|
||||
return fallback
|
||||
|
||||
# 4) If no configs available at all
|
||||
raise ValueError(
|
||||
f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available"
|
||||
f"Chunker '{chunker_strategy}' not found "
|
||||
)
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
def get_pipeline_config(
    config_id: int,
    db: Session | None = None,
) -> ExtractionPipelineConfig:
    """Build an ExtractionPipelineConfig from a stored configuration row.

    Args:
        config_id: Database configuration ID (required). The row is fetched
            with ``DataConfigRepository.get_by_id``.
        db: Optional database session. When omitted, a session is taken from
            ``get_db()`` and that generator is closed before this function
            returns or raises.

    Returns:
        ExtractionPipelineConfig assembled from the row's statement-extraction,
        deduplication and forgetting-engine columns. Columns whose value is
        ``None`` are omitted from the constructor arguments, except the two
        LLM dedup switches, which are always passed (``None`` becomes False).

    Raises:
        ValueError: If config_id is not found in the database.
    """
    from app.repositories.data_config_repository import DataConfigRepository

    def _drop_none(**values):
        """Return only the keyword arguments whose value is not None."""
        return {key: value for key, value in values.items() if value is not None}

    # Track a session generator we opened ourselves so it can be closed
    # deterministically instead of waiting for garbage collection.
    owned_session_gen = None
    if db is None:
        owned_session_gen = get_db()
        db = next(owned_session_gen)

    try:
        db_config = DataConfigRepository.get_by_id(db, config_id)
        if db_config is None:
            raise ValueError(f"Configuration {config_id} not found in database")

        # Deduplication: the two switches are always set (bool(None) is False);
        # fuzzy thresholds are passed only when the column holds a value.
        dedup_config = DedupConfig(
            enable_llm_dedup_blockwise=bool(db_config.enable_llm_dedup_blockwise),
            enable_llm_disambiguation=bool(db_config.enable_llm_disambiguation),
            **_drop_none(
                fuzzy_name_threshold_strict=db_config.t_name_strict,
                fuzzy_type_threshold_strict=db_config.t_type_strict,
                fuzzy_overall_threshold=db_config.t_overall,
            ),
        )

        # Statement extraction: include_dialogue_context is coerced to bool
        # when present; the other values are passed through unchanged.
        stmt_kwargs = _drop_none(
            statement_granularity=db_config.statement_granularity,
            include_dialogue_context=db_config.include_dialogue_context,
            max_dialogue_context_chars=db_config.max_context,
        )
        if "include_dialogue_context" in stmt_kwargs:
            stmt_kwargs["include_dialogue_context"] = bool(
                stmt_kwargs["include_dialogue_context"]
            )
        stmt_config = StatementExtractionConfig(**stmt_kwargs)

        # Forgetting engine.
        forget_config = ForgettingEngineConfig(
            **_drop_none(
                offset=db_config.offset,
                lambda_time=db_config.lambda_time,
                lambda_mem=db_config.lambda_mem,
            )
        )

        return ExtractionPipelineConfig(
            statement_extraction=stmt_config,
            deduplication=dedup_config,
            forgetting_engine=forget_config,
        )
    finally:
        if owned_session_gen is not None:
            # Assumes get_db() is a generator function (next() is called on
            # its result above); close() lets its cleanup code run now.
            owned_session_gen.close()
|
||||
|
||||
|
||||
def get_pruning_config(
    config_id: int,
    db: Session | None = None,
) -> dict:
    """Retrieve the semantic pruning settings of a stored configuration row.

    Args:
        config_id: Database configuration ID (required).
        db: Optional database session. When omitted, a session is taken from
            ``get_db()`` and that generator is closed before this function
            returns or raises.

    Returns:
        Dict suitable for PruningConfig.model_validate with keys:
        - pruning_switch: bool (False when the column is None)
        - pruning_scene: str ("education" | "online_service" | "outbound";
          "education" when the column is empty or None)
        - pruning_threshold: float (0-0.9; 0.5 when the column is None)

    Raises:
        ValueError: If config_id is not found in the database.
    """
    from app.repositories.data_config_repository import DataConfigRepository

    # Track a session generator we opened ourselves so it can be closed
    # deterministically instead of waiting for garbage collection.
    owned_session_gen = None
    if db is None:
        owned_session_gen = get_db()
        db = next(owned_session_gen)

    try:
        db_config = DataConfigRepository.get_by_id(db, config_id)
        if db_config is None:
            raise ValueError(f"Configuration {config_id} not found in database")

        threshold = db_config.pruning_threshold
        return {
            # bool(None) is False, matching the explicit None check it replaces.
            "pruning_switch": bool(db_config.pruning_enabled),
            "pruning_scene": db_config.pruning_scene or "education",
            "pruning_threshold": float(threshold) if threshold is not None else 0.5,
        }
    finally:
        if owned_session_gen is not None:
            # Assumes get_db() is a generator function (next() is called on
            # its result above); close() lets its cleanup code run now.
            owned_session_gen.close()
|
||||
|
||||
@@ -1,269 +1,268 @@
|
||||
"""
|
||||
配置加载模块 - DEPRECATED
|
||||
# """
|
||||
# 配置加载模块 - DEPRECATED
|
||||
|
||||
⚠️ DEPRECATION NOTICE ⚠️
|
||||
This module is deprecated and will be removed in a future version.
|
||||
Global configuration variables have been eliminated in favor of dependency injection.
|
||||
# ⚠️ DEPRECATION NOTICE ⚠️
|
||||
# This module is deprecated and will be removed in a future version.
|
||||
# Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
Use the new MemoryConfig system instead:
|
||||
- app.core.memory_config.config.MemoryConfig for configuration objects
|
||||
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
# Use the new MemoryConfig system instead:
|
||||
# - app.schemas.memory_config_schema.MemoryConfig for configuration objects
|
||||
# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id)
|
||||
|
||||
阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED
|
||||
阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED
|
||||
阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
# 阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED
|
||||
# 阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED
|
||||
# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
# """
|
||||
# import json
|
||||
# import os
|
||||
# import threading
|
||||
# from datetime import datetime, timedelta
|
||||
# from typing import Any, Dict, Optional
|
||||
|
||||
#TODO: Fix this
|
||||
# #TODO: Fix this
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except Exception:
|
||||
pass
|
||||
# try:
|
||||
# from dotenv import load_dotenv
|
||||
# load_dotenv()
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
# Import unified configuration system
|
||||
try:
|
||||
from app.core.config import settings
|
||||
USE_UNIFIED_CONFIG = True
|
||||
except ImportError:
|
||||
USE_UNIFIED_CONFIG = False
|
||||
settings = None
|
||||
# # Import unified configuration system
|
||||
# try:
|
||||
# from app.core.config import settings
|
||||
# USE_UNIFIED_CONFIG = True
|
||||
# except ImportError:
|
||||
# USE_UNIFIED_CONFIG = False
|
||||
# settings = None
|
||||
|
||||
# PROJECT_ROOT 应该指向 app/core/memory/ 目录
|
||||
# __file__ = app/core/memory/utils/config/definitions.py
|
||||
# os.path.dirname(__file__) = app/core/memory/utils/config
|
||||
# os.path.dirname(...) = app/core/memory/utils
|
||||
# os.path.dirname(...) = app/core/memory
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# # PROJECT_ROOT 应该指向 app/core/memory/ 目录
|
||||
# # __file__ = app/core/memory/utils/config/definitions.py
|
||||
# # os.path.dirname(__file__) = app/core/memory/utils/config
|
||||
# # os.path.dirname(...) = app/core/memory/utils
|
||||
# # os.path.dirname(...) = app/core/memory
|
||||
# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# DEPRECATED: Global configuration lock removed
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
# # DEPRECATED: Global configuration lock removed
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
|
||||
# DEPRECATED: Legacy config.json loading removed
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
CONFIG = {}
|
||||
# # DEPRECATED: Legacy config.json loading removed
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
# CONFIG = {}
|
||||
|
||||
DEFAULT_VALUES = {
|
||||
"llm_name": "openai/qwen-plus",
|
||||
"embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
"group_id": "group_123",
|
||||
"user_id": "default_user",
|
||||
"apply_id": "default_apply",
|
||||
"llm_agent_name": "openai/qwen-plus",
|
||||
"llm_verify_name": "openai/qwen-plus",
|
||||
"llm_image_recognition": "openai/qwen-plus",
|
||||
"llm_voice_recognition": "openai/qwen-plus",
|
||||
"prompt_level": "DEBUG",
|
||||
"reflexion_iteration_period": "3",
|
||||
"reflexion_range": "retrieval",
|
||||
"reflexion_baseline": "TIME",
|
||||
}
|
||||
# DEFAULT_VALUES = {
|
||||
# "llm_name": "openai/qwen-plus",
|
||||
# "embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
# "chunker_strategy": "RecursiveChunker",
|
||||
# "group_id": "group_123",
|
||||
# "user_id": "default_user",
|
||||
# "apply_id": "default_apply",
|
||||
# "llm_agent_name": "openai/qwen-plus",
|
||||
# "llm_verify_name": "openai/qwen-plus",
|
||||
# "llm_image_recognition": "openai/qwen-plus",
|
||||
# "llm_voice_recognition": "openai/qwen-plus",
|
||||
# "prompt_level": "DEBUG",
|
||||
# "reflexion_iteration_period": "3",
|
||||
# "reflexion_range": "retrieval",
|
||||
# "reflexion_baseline": "TIME",
|
||||
# }
|
||||
|
||||
# DEPRECATED: Legacy global variables for backward compatibility only
|
||||
# These will be removed in a future version
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
|
||||
SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
|
||||
# # DEPRECATED: Legacy global variables for backward compatibility only
|
||||
# # These will be removed in a future version
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
|
||||
# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
|
||||
|
||||
|
||||
# 阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
def _load_from_runtime_json() -> Dict[str, Any]:
|
||||
"""
|
||||
DEPRECATED: Legacy runtime.json loading
|
||||
# # 阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
# def _load_from_runtime_json() -> Dict[str, Any]:
|
||||
# """
|
||||
# DEPRECATED: Legacy runtime.json loading
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Empty configuration (legacy support only)
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return {"selections": {}}
|
||||
# Returns:
|
||||
# Dict[str, Any]: Empty configuration (legacy support only)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return {"selections": {}}
|
||||
|
||||
|
||||
# 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器
|
||||
# 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
|
||||
# 保留此函数仅为向后兼容
|
||||
def _load_from_database() -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
DEPRECATED: Legacy database configuration loading
|
||||
# # 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器
|
||||
# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
|
||||
# # 保留此函数仅为向后兼容
|
||||
# def _load_from_database() -> Optional[Dict[str, Any]]:
|
||||
# """
|
||||
# DEPRECATED: Legacy database configuration loading
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: None (deprecated functionality)
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return None
|
||||
# Returns:
|
||||
# Optional[Dict[str, Any]]: None (deprecated functionality)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return None
|
||||
|
||||
|
||||
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
|
||||
"""
|
||||
DEPRECATED: 将运行时配置暴露为全局常量供项目使用
|
||||
# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
|
||||
# """
|
||||
# DEPRECATED: 将运行时配置暴露为全局常量供项目使用
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Global configuration variables have been eliminated in favor of dependency injection.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
Use the new MemoryConfig system instead:
|
||||
- app.core.memory_config.config.MemoryConfig for configuration objects
|
||||
- Pass configuration objects as parameters instead of using global variables
|
||||
# Use the new MemoryConfig system instead:
|
||||
# - app.schemas.memory_config_schema.MemoryConfig for configuration objects
|
||||
# - Pass configuration objects as parameters instead of using global variables
|
||||
|
||||
Args:
|
||||
runtime_cfg: 运行时配置字典
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# Args:
|
||||
# runtime_cfg: 运行时配置字典
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# Keep minimal global state for backward compatibility only
|
||||
# These will be removed in a future version
|
||||
global RUNTIME_CONFIG, SELECTIONS
|
||||
# # Keep minimal global state for backward compatibility only
|
||||
# # These will be removed in a future version
|
||||
# global RUNTIME_CONFIG, SELECTIONS
|
||||
|
||||
RUNTIME_CONFIG = runtime_cfg
|
||||
SELECTIONS = RUNTIME_CONFIG.get("selections", {})
|
||||
# RUNTIME_CONFIG = runtime_cfg
|
||||
# SELECTIONS = RUNTIME_CONFIG.get("selections", {})
|
||||
|
||||
# All other global variables have been removed
|
||||
# Use MemoryConfig objects instead
|
||||
# # All other global variables have been removed
|
||||
# # Use MemoryConfig objects instead
|
||||
|
||||
|
||||
# 初始化:使用统一配置加载器
|
||||
def _initialize_configuration() -> None:
|
||||
"""
|
||||
DEPRECATED: Legacy configuration initialization
|
||||
# # 初始化:使用统一配置加载器
|
||||
# def _initialize_configuration() -> None:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration initialization
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# Initialize with empty configuration for backward compatibility
|
||||
_expose_runtime_constants({"selections": {}})
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# # Initialize with empty configuration for backward compatibility
|
||||
# _expose_runtime_constants({"selections": {}})
|
||||
|
||||
|
||||
# 模块加载时自动初始化配置
|
||||
_initialize_configuration()
|
||||
# # 模块加载时自动初始化配置
|
||||
# _initialize_configuration()
|
||||
|
||||
# DEPRECATED: Global variables removed
|
||||
# These variables have been eliminated in favor of dependency injection
|
||||
# Use MemoryConfig objects instead of accessing global variables
|
||||
# # DEPRECATED: Global variables removed
|
||||
# # These variables have been eliminated in favor of dependency injection
|
||||
# # Use MemoryConfig objects instead of accessing global variables
|
||||
|
||||
|
||||
# 公共 API:动态重新加载配置
|
||||
def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
|
||||
"""
|
||||
DEPRECATED: Legacy configuration reloading
|
||||
# # 公共 API:动态重新加载配置
|
||||
# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration reloading
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
For new code, use:
|
||||
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
# For new code, use:
|
||||
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID (deprecated)
|
||||
force_reload: Force reload flag (deprecated)
|
||||
# Args:
|
||||
# config_id: Configuration ID (deprecated)
|
||||
# force_reload: Force reload flag (deprecated)
|
||||
|
||||
Returns:
|
||||
bool: Always returns False (deprecated functionality)
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
# Returns:
|
||||
# bool: Always returns False (deprecated functionality)
|
||||
# """
|
||||
# import logging
|
||||
# import warnings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
warnings.warn(
|
||||
"reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# warnings.warn(
|
||||
# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
|
||||
"Use MemoryConfig objects with dependency injection instead.")
|
||||
# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
|
||||
# "Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
return False
|
||||
# return False
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_current_config_id() -> Optional[str]:
|
||||
"""
|
||||
DEPRECATED: Legacy config ID retrieval
|
||||
# def get_current_config_id() -> Optional[str]:
|
||||
# """
|
||||
# DEPRECATED: Legacy config ID retrieval
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None (deprecated functionality)
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return None
|
||||
# Returns:
|
||||
# Optional[str]: None (deprecated functionality)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return None
|
||||
|
||||
|
||||
def ensure_fresh_config(config_id = None) -> bool:
|
||||
"""
|
||||
DEPRECATED: Legacy configuration freshness check
|
||||
# def ensure_fresh_config(config_id = None) -> bool:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration freshness check
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
For new code, use:
|
||||
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
# For new code, use:
|
||||
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID (deprecated)
|
||||
# Args:
|
||||
# config_id: Configuration ID (deprecated)
|
||||
|
||||
Returns:
|
||||
bool: Always returns False (deprecated functionality)
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
# Returns:
|
||||
# bool: Always returns False (deprecated functionality)
|
||||
# """
|
||||
# import logging
|
||||
# import warnings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
warnings.warn(
|
||||
"ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# warnings.warn(
|
||||
# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
|
||||
"Use MemoryConfig objects with dependency injection instead.")
|
||||
# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
|
||||
# "Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
return False
|
||||
# return False
|
||||
|
||||
|
||||
|
||||
@@ -6,8 +6,9 @@ This module provides centralized functions for creating embedder clients.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -66,7 +67,8 @@ def get_embedder_client(embedding_id: str) -> OpenAIEmbedderClient:
|
||||
raise ValueError("Embedding ID is required but was not provided")
|
||||
|
||||
try:
|
||||
embedder_config_dict = get_embedder_config(embedding_id)
|
||||
with get_db_context() as db:
|
||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -13,105 +13,225 @@ async def handle_response(response: type[BaseModel]) -> dict:
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
def get_llm_client_from_config(memory_config: "MemoryConfig") -> OpenAIClient:
|
||||
class MemoryClientFactory:
|
||||
"""
|
||||
Get LLM client from MemoryConfig object.
|
||||
Factory for creating LLM, embedder, and reranker clients.
|
||||
|
||||
**PREFERRED METHOD**: Use this function in production code when you have a MemoryConfig object.
|
||||
This ensures proper configuration management and multi-tenant support.
|
||||
Initialize once with db session, then call methods without passing db each time.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing llm_model_id
|
||||
|
||||
Returns:
|
||||
OpenAIClient: Initialized LLM client
|
||||
|
||||
Raises:
|
||||
ValueError: If LLM model ID is not configured or client initialization fails
|
||||
|
||||
Example:
|
||||
>>> llm_client = get_llm_client_from_config(memory_config)
|
||||
>>> factory = MemoryClientFactory(db)
|
||||
>>> llm_client = factory.get_llm_client(model_id)
|
||||
>>> embedder_client = factory.get_embedder_client(embedding_id)
|
||||
"""
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"Configuration {memory_config.config_id} has no LLM model configured"
|
||||
)
|
||||
return get_llm_client(str(memory_config.llm_model_id))
|
||||
|
||||
def __init__(self, db: Session):
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
self._config_service = MemoryConfigService(db)
|
||||
|
||||
def get_llm_client(self, llm_id: str) -> OpenAIClient:
|
||||
"""Get LLM client by model ID."""
|
||||
if not llm_id:
|
||||
raise ValueError("LLM ID is required")
|
||||
|
||||
try:
|
||||
model_config = self._config_service.get_model_config(llm_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
return OpenAIClient(
|
||||
RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),
|
||||
type_=model_config.get("type")
|
||||
)
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
def get_embedder_client(self, embedding_id: str):
|
||||
"""Get embedder client by model ID."""
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
|
||||
if not embedding_id:
|
||||
raise ValueError("Embedding ID is required")
|
||||
|
||||
try:
|
||||
embedder_config = self._config_service.get_embedder_config(embedding_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
return OpenAIEmbedderClient(
|
||||
RedBearModelConfig(
|
||||
model_name=embedder_config.get("model_name"),
|
||||
provider=embedder_config.get("provider"),
|
||||
api_key=embedder_config.get("api_key"),
|
||||
base_url=embedder_config.get("base_url")
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
model_name = embedder_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize embedder client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
def get_reranker_client(self, rerank_id: str) -> OpenAIClient:
|
||||
"""Get reranker client by model ID."""
|
||||
if not rerank_id:
|
||||
raise ValueError("Rerank ID is required")
|
||||
|
||||
try:
|
||||
model_config = self._config_service.get_model_config(rerank_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
return OpenAIClient(
|
||||
RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),
|
||||
type_=model_config.get("type")
|
||||
)
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
def get_llm_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
|
||||
"""Get LLM client from MemoryConfig object.
|
||||
|
||||
Args:
|
||||
memory_config: Configuration containing llm_model_id
|
||||
|
||||
Returns:
|
||||
OpenAIClient configured for the LLM model
|
||||
|
||||
Raises:
|
||||
ValueError: If memory_config has no LLM model configured
|
||||
"""
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"Configuration {memory_config.config_id} has no LLM model configured"
|
||||
)
|
||||
return self.get_llm_client(str(memory_config.llm_model_id))
|
||||
|
||||
def get_embedder_client_from_config(self, memory_config: "MemoryConfig"):
    """Resolve the embedder client referenced by a ``MemoryConfig``.

    Args:
        memory_config: Configuration object exposing ``embedding_model_id``
            and ``config_id``.

    Returns:
        The client returned by ``self.get_embedder_client`` for the
        configured model ID (converted to ``str``).

    Raises:
        ValueError: If ``memory_config.embedding_model_id`` is empty or
            missing.
    """
    embedding_model_id = memory_config.embedding_model_id
    if embedding_model_id:
        # The by-ID lookup takes a string, so normalise the ID first.
        return self.get_embedder_client(str(embedding_model_id))

    raise ValueError(
        f"Configuration {memory_config.config_id} has no embedding model configured"
    )
|
||||
|
||||
def get_reranker_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
    """Resolve the reranker client referenced by a ``MemoryConfig``.

    Args:
        memory_config: Configuration object exposing ``rerank_model_id`` and
            ``config_id``.

    Returns:
        The client returned by ``self.get_reranker_client`` for the
        configured model ID (converted to ``str``).

    Raises:
        ValueError: If ``memory_config.rerank_model_id`` is empty or missing.
    """
    rerank_model_id = memory_config.rerank_model_id
    if rerank_model_id:
        # The by-ID lookup takes a string, so normalise the ID first.
        return self.get_reranker_client(str(rerank_model_id))

    raise ValueError(
        f"Configuration {memory_config.config_id} has no rerank model configured"
    )
|
||||
|
||||
|
||||
# Legacy functions for backward compatibility


def get_llm_client_from_config(memory_config: "MemoryConfig", db: Session) -> OpenAIClient:
    """Get LLM client from MemoryConfig object.

    DEPRECATED: Use MemoryClientFactory(db).get_llm_client_from_config(memory_config) instead.

    This function is maintained for backward compatibility during migration to the
    factory pattern. New code should create a MemoryClientFactory instance and use
    its get_llm_client_from_config method directly.

    Args:
        memory_config: Configuration containing llm_model_id
        db: Database session

    Returns:
        OpenAIClient configured for the LLM model

    Raises:
        ValueError: If memory_config has no LLM model configured
    """
    # The stale pre-refactor body (an `llm_id` guard and direct
    # `get_model_config` lookup) is gone: `llm_id` is not a parameter here,
    # and all client construction now lives in MemoryClientFactory.
    return MemoryClientFactory(db).get_llm_client_from_config(memory_config)
|
||||
|
||||
|
||||
def get_llm_client(llm_id: str, db: Session) -> OpenAIClient:
    """Get LLM client by model ID.

    DEPRECATED: Use MemoryClientFactory(db).get_llm_client(llm_id) instead.

    This function is maintained for backward compatibility during migration to the
    factory pattern. New code should create a MemoryClientFactory instance and use
    its get_llm_client method directly.

    Args:
        llm_id: LLM model ID
        db: Database session

    Returns:
        OpenAIClient configured for the LLM model

    Note:
        No validation happens in this wrapper; any exception raised by the
        factory method propagates to the caller unchanged.
    """
    # The stale `rerank_id` guard left over from the removed reranker helper
    # is dropped: `rerank_id` is not a parameter of this function.
    return MemoryClientFactory(db).get_llm_client(llm_id)
|
||||
|
||||
|
||||
def get_embedder_client(embedding_id: str, db: Session):
    """Return an embedder client for ``embedding_id`` (deprecated wrapper).

    DEPRECATED: call ``MemoryClientFactory(db).get_embedder_client(embedding_id)``
    directly. This module-level function only exists so that callers written
    before the factory pattern keep working during the migration.

    Args:
        embedding_id: Embedding model ID.
        db: Database session handed to ``MemoryClientFactory``.

    Returns:
        Whatever ``MemoryClientFactory.get_embedder_client`` returns for the
        given ID (documented there as an ``OpenAIEmbedderClient``).
    """
    factory = MemoryClientFactory(db)
    return factory.get_embedder_client(embedding_id)
|
||||
|
||||
|
||||
def get_reranker_client(rerank_id: str, db: Session) -> OpenAIClient:
    """Return a reranker client for ``rerank_id`` (deprecated wrapper).

    DEPRECATED: call ``MemoryClientFactory(db).get_reranker_client(rerank_id)``
    directly. This module-level function only exists so that callers written
    before the factory pattern keep working during the migration.

    Args:
        rerank_id: Reranker model ID.
        db: Database session handed to ``MemoryClientFactory``.

    Returns:
        The ``OpenAIClient`` produced by the factory for the reranker model.

    Raises:
        ValueError: Propagated from the factory method when ``rerank_id`` is
            empty, unknown, or the client cannot be initialised.
    """
    factory = MemoryClientFactory(db)
    return factory.get_reranker_client(rerank_id)
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Any
|
||||
import time
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -25,7 +26,9 @@ async def conflict(evaluate_data: List[Any]) -> List[Any]:
|
||||
冲突记忆列表(JSON 数组)。
|
||||
"""
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
print(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Any
|
||||
import time
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -25,7 +26,9 @@ async def reflexion(ref_data: List[Any]) -> List[Any]:
|
||||
反思结果列表(JSON 数组)。
|
||||
"""
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema)
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
print(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
Reference in New Issue
Block a user