refactor(memory): restructure memory agent and config management
- Reorganize imports and remove unused dependencies across memory agent controllers
- Extract config validation logic into dedicated validators module
- Create new memory_config_model and memory_config_schema for configuration management
- Implement memory_config_service for centralized config handling
- Add embedder_utils module for embedding model utilities
- Refactor memory agent service to use new config validation framework
- Clean up configuration files (remove config.json, testdata.json, dbrun.json)
- Remove deprecated hybrid_chatbot.py and config overrides
- Update logging configuration and error handling across memory modules
- Consolidate LLM and embedding model validation into validators
- Improve code organization and reduce duplication in memory storage services
- Enhance type classification and verification tools with better error handling
This commit is contained in:
@@ -1,36 +1,28 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
from app.db import get_db
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.models import ModelApiKey, Knowledge
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
||||
from typing import List, Optional
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services import task_service, workspace_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||
from app.models import ModelApiKey
|
||||
from app.models.user_model import User
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from fastapi import APIRouter, Depends, File, UploadFile, Form
|
||||
from app.repositories import knowledge_repository
|
||||
from app.services import task_service, workspace_service
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Initialize service
|
||||
memory_agent_service = MemoryAgentService()
|
||||
|
||||
router = APIRouter(
|
||||
@@ -39,95 +31,6 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _validate_model_reference(db, model_cls, model_id, config_id, *, label, error_label, id_field):
    """Check that a model referenced by a data config exists and is active.

    Args:
        db: Database session used to look the model up.
        model_cls: ORM class to query (filtered on its ``id`` column).
        model_id: Value of the reference stored on the config; falsy means unset.
        config_id: Owning configuration ID, used only in messages.
        label: Model kind as it appears at the start of messages ("LLM", "Embedding").
        error_label: Model kind as it appears in the wrapped-error message.
        id_field: Name of the config attribute holding ``model_id``.

    Raises:
        ValueError: If the reference is unset, the model does not exist, the
            model is not active, or the lookup itself fails.
    """
    if not model_id:
        error_msg = f"Config {config_id} has no {id_field} set"
        api_logger.error(error_msg)
        raise ValueError(error_msg)

    try:
        model = db.query(model_cls).filter(model_cls.id == model_id).first()
        if model is None:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) does not exist"
            api_logger.error(error_msg)
            raise ValueError(error_msg)
        if not model.is_active:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) is not active"
            api_logger.error(error_msg)
            raise ValueError(error_msg)
        api_logger.debug(f"{label} validation successful: {id_field}={model_id}, name={model.name}")
    except ValueError:
        # Validation failures raised above are already logged; pass them through.
        raise
    except Exception as e:
        error_msg = f"Error validating {error_label} model: {str(e)}"
        api_logger.error(error_msg, exc_info=True)
        # Chain the cause so the original traceback is not lost.
        raise ValueError(error_msg) from e


def validate_config_id(config_id: int, db: Session) -> int:
    """
    Validate and ensure config_id is available, valid, and exists in database.

    When ``config_id`` is None the ``config_id`` environment variable is used
    as a fallback. The configuration row must exist and must reference an
    existing, active LLM model and embedding model.

    Args:
        config_id: Configuration ID to validate
        db: Database session for checking existence

    Returns:
        int: Validated config_id

    Raises:
        ValueError: If config_id is None, invalid, or doesn't exist in database,
            or if the referenced LLM / embedding model is missing or inactive.
            Unexpected database errors are also re-raised as ValueError.
    """
    if config_id is None:
        api_logger.info("config_id is required but was not provided")
        config_id = os.getenv('config_id')
        if config_id is None:
            raise ValueError("config_id is required but was not provided")
        # os.getenv returns a string; convert so the declared int contract holds.
        try:
            config_id = int(config_id)
        except ValueError as e:
            raise ValueError(
                f"config_id from environment is not a valid integer: {config_id!r}"
            ) from e

    # Check if config exists in database
    try:
        from app.models.data_config_model import DataConfig
        from app.models.models_model import ModelConfig

        config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
        if config is None:
            error_msg = f"Configuration with config_id={config_id} does not exist in database"
            api_logger.error(error_msg)
            raise ValueError(error_msg)

        # Validate llm_id exists and is usable
        _validate_model_reference(
            db, ModelConfig, config.llm_id, config_id,
            label="LLM", error_label="LLM", id_field="llm_id",
        )

        # Validate embedding_id exists and is usable
        _validate_model_reference(
            db, ModelConfig, config.embedding_id, config_id,
            label="Embedding", error_label="embedding", id_field="embedding_id",
        )

        api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
        return config_id
    except ValueError:
        # Re-raise ValueError from above
        raise
    except Exception as e:
        error_msg = f"Database error while validating config_id={config_id}: {str(e)}"
        api_logger.error(error_msg, exc_info=True)
        raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
@router.get("/health/status", response_model=ApiResponse)
|
||||
async def get_health_status(
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -225,12 +128,7 @@ async def write_server(
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
@@ -270,8 +168,14 @@ async def write_server(
|
||||
user_rag_memory_id
|
||||
)
|
||||
return success(data=result, msg="写入成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Write operation error: {str(e)}")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@@ -292,12 +196,7 @@ async def write_server_async(
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
@@ -352,12 +251,7 @@ async def read_server(
|
||||
Returns:
|
||||
Response with query answer
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
@@ -390,8 +284,14 @@ async def read_server(
|
||||
user_rag_memory_id
|
||||
)
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read operation error: {str(e)}")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Read operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", detailed_error)
|
||||
api_logger.error(f"Read operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
|
||||
|
||||
|
||||
@@ -456,12 +356,7 @@ async def read_server_async(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
|
||||
@@ -1,45 +1,45 @@
|
||||
from typing import Optional, Union
|
||||
import os
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_storage_service import (
|
||||
MemoryStorageService,
|
||||
DataConfigService,
|
||||
kb_type_distribution,
|
||||
search_dialogue,
|
||||
search_chunk,
|
||||
search_statement,
|
||||
search_entity,
|
||||
search_all,
|
||||
search_detials,
|
||||
search_edges,
|
||||
search_entity_graph,
|
||||
MemoryStorageService,
|
||||
analytics_hot_memory_tags,
|
||||
analytics_memory_insight_report,
|
||||
analytics_recent_activity_stats,
|
||||
analytics_user_summary,
|
||||
kb_type_distribution,
|
||||
search_all,
|
||||
search_chunk,
|
||||
search_detials,
|
||||
search_dialogue,
|
||||
search_edges,
|
||||
search_entity,
|
||||
search_entity_graph,
|
||||
search_statement,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
ConfigPilotRun,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -329,8 +329,10 @@ async def pilot_run(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
|
||||
|
||||
api_logger.info(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||
)
|
||||
svc = DataConfigService(db)
|
||||
return StreamingResponse(
|
||||
svc.pilot_run_stream(payload),
|
||||
@@ -338,8 +340,8 @@ async def pilot_run(
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -528,8 +530,8 @@ async def get_user_summary_api(
|
||||
except Exception as e:
|
||||
api_logger.error(f"User summary failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e))
|
||||
|
||||
from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||
|
||||
|
||||
@router.get("/self_reflexion")
|
||||
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
|
||||
"""
|
||||
|
||||
@@ -326,7 +326,7 @@ def log_prompt_rendering(prompt_type: str, content: str) -> None:
|
||||
logger.info(log_message)
|
||||
|
||||
|
||||
def log_template_rendering(template_name: str, context: dict | None = None) -> None:
|
||||
def log_template_rendering(template_name: str, context: Optional[dict] = None) -> None:
|
||||
"""Log template rendering information.
|
||||
|
||||
Logs the template name and context keys for debugging template rendering.
|
||||
@@ -575,6 +575,43 @@ def get_named_logger(name: str) -> logging.Logger:
|
||||
return get_agent_logger(name)
|
||||
|
||||
|
||||
def get_config_logger() -> logging.Logger:
    """Return the logger used for memory configuration operations.

    Memory logging is set up on first use via
    ``LoggingConfig.setup_memory_logging()`` when
    ``LoggingConfig._memory_loggers_initialized`` is false. The logger is
    then fetched by its fixed name, ``memory.config``.

    No handlers, formatters, or level are attached here, so records are
    handled by whatever is configured further up the logging hierarchy.

    Returns:
        The ``memory.config`` logger.

    Example:
        >>> logger = get_config_logger()
        >>> logger.info("Loading configuration", extra={
        ...     "config_id": 123,
        ...     "workspace_id": "uuid-here",
        ...     "operation": "load_config"
        ... })
    """
    memory_logging_ready = LoggingConfig._memory_loggers_initialized
    if not memory_logging_ready:
        # Lazily initialize memory logging before handing out the logger.
        LoggingConfig.setup_memory_logging()

    return logging.getLogger("memory.config")
|
||||
|
||||
|
||||
def get_memory_logger(name: Optional[str] = None) -> logging.Logger:
|
||||
"""Get a standard logger for memory module components.
|
||||
|
||||
|
||||
@@ -9,11 +9,11 @@ import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +25,8 @@ async def create_input_message(
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
multimodal_processor: MultimodalProcessor
|
||||
multimodal_processor: MultimodalProcessor,
|
||||
memory_config: MemoryConfig,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create initial tool call message from user input.
|
||||
@@ -46,6 +47,7 @@ async def create_input_message(
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
multimodal_processor: Processor for handling image/audio inputs
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
State update with AIMessage containing tool_call
|
||||
@@ -53,7 +55,7 @@ async def create_input_message(
|
||||
Examples:
|
||||
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
|
||||
>>> result = await create_input_message(
|
||||
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor
|
||||
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config
|
||||
... )
|
||||
>>> result["messages"][0].tool_calls[0]["name"]
|
||||
'Split_The_Problem'
|
||||
@@ -123,20 +125,24 @@ async def create_input_message(
|
||||
f"with ID: {tool_call_id}"
|
||||
)
|
||||
|
||||
# Build tool arguments
|
||||
tool_args = {
|
||||
"sentence": last_message,
|
||||
"sessionid": session_id,
|
||||
"messages_id": str(uuid_str),
|
||||
"search_switch": search_switch,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"memory_config": memory_config,
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": tool_name,
|
||||
"args": {
|
||||
"sentence": last_message,
|
||||
"sessionid": session_id,
|
||||
"messages_id": str(uuid_str),
|
||||
"search_switch": search_switch,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
},
|
||||
"args": tool_args,
|
||||
"id": tool_call_id
|
||||
}]
|
||||
)
|
||||
|
||||
@@ -9,14 +9,14 @@ import logging
|
||||
import time
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||
extract_content_payload,
|
||||
extract_tool_call_id,
|
||||
extract_content_payload
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,8 +38,9 @@ class ToolExecutionNode:
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool: Callable,
|
||||
@@ -49,8 +50,9 @@ class ToolExecutionNode:
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
parameter_builder: ParameterBuilder,
|
||||
storage_type:str,
|
||||
user_rag_memory_id:str
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
):
|
||||
"""
|
||||
Initialize the tool execution node.
|
||||
@@ -63,6 +65,9 @@ class ToolExecutionNode:
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
storage_type: Storage type for the workspace
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
self.tool_node = ToolNode([tool])
|
||||
self.id = node_id
|
||||
@@ -72,9 +77,10 @@ class ToolExecutionNode:
|
||||
self.apply_id = apply_id
|
||||
self.group_id = group_id
|
||||
self.parameter_builder = parameter_builder
|
||||
self.storage_type=storage_type
|
||||
self.user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
self.storage_type = storage_type
|
||||
self.user_rag_memory_id = user_rag_memory_id
|
||||
self.memory_config = memory_config
|
||||
|
||||
logger.info(
|
||||
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
|
||||
)
|
||||
@@ -124,8 +130,12 @@ class ToolExecutionNode:
|
||||
# Extract content payload using state extractors
|
||||
content = extract_content_payload(last_message)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}"
|
||||
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}"
|
||||
)
|
||||
# Log raw message content for debugging
|
||||
if hasattr(last_message, 'content'):
|
||||
raw = last_message.content
|
||||
logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -143,8 +153,9 @@ class ToolExecutionNode:
|
||||
search_switch=self.search_switch,
|
||||
apply_id=self.apply_id,
|
||||
group_id=self.group_id,
|
||||
memory_config=self.memory_config,
|
||||
storage_type=self.storage_type,
|
||||
user_rag_memory_id=self.user_rag_memory_id
|
||||
user_rag_memory_id=self.user_rag_memory_id,
|
||||
)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
|
||||
@@ -179,7 +190,29 @@ class ToolExecutionNode:
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution completed"
|
||||
)
|
||||
|
||||
# Return the result directly - it already contains the messages list
|
||||
# Check for error in tool response
|
||||
error_entry = None
|
||||
if result and "messages" in result:
|
||||
for msg in result["messages"]:
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
import json
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict) and "error" in parsed:
|
||||
error_msg = parsed["error"]
|
||||
logger.warning(
|
||||
f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}"
|
||||
)
|
||||
error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Return result with error tracking if error was found
|
||||
if error_entry:
|
||||
result["errors"] = [error_entry]
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
@@ -187,13 +220,15 @@ class ToolExecutionNode:
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return error as ToolMessage to maintain message chain consistency
|
||||
# Track error in state and return error message
|
||||
from langchain_core.messages import ToolMessage
|
||||
error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id}
|
||||
return {
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=f"Error executing tool: {str(e)}",
|
||||
tool_call_id=f"{self.id}_{tool_call_id}"
|
||||
)
|
||||
]
|
||||
],
|
||||
"errors": [error_entry]
|
||||
}
|
||||
|
||||
@@ -1,38 +1,26 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from functools import partial
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
# Import new modular components
|
||||
from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message
|
||||
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Verify_continue,
|
||||
Retrieve_continue,
|
||||
Split_continue
|
||||
from app.core.memory.agent.langgraph_graph.nodes import (
|
||||
ToolExecutionNode,
|
||||
create_input_message,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -44,9 +32,9 @@ redisdb=os.getenv('REDISDB')
|
||||
redispassword=os.getenv('REDISPASSWORD')
|
||||
counter = COUNTState(limit=3)
|
||||
|
||||
# 在工作流中添加循环计数更新
|
||||
# Update loop count in workflow
|
||||
async def update_loop_count(state):
    """Return a state update that increments the loop counter by one.

    Args:
        state: Mapping-like workflow state; ``loop_count`` is read with a
            default of 0 when the key is absent.

    Returns:
        A dict with the single key ``loop_count`` set to the previous
        value plus one. The input ``state`` is not modified.
    """
    current_count = state.get("loop_count", 0)
    return {"loop_count": current_count + 1}
|
||||
|
||||
@@ -54,13 +42,13 @@ async def update_loop_count(state):
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
messages = state["messages"]
|
||||
|
||||
# 添加边界检查
|
||||
# Add boundary check
|
||||
if not messages:
|
||||
return END
|
||||
counter.add(1) # 累加 1
|
||||
counter.add(1) # Increment by 1
|
||||
|
||||
loop_count = counter.get_total()
|
||||
logger.debug(f"[should_continue] 当前循环次数: {loop_count}")
|
||||
logger.debug(f"[should_continue] Current loop count: {loop_count}")
|
||||
|
||||
last_message = messages[-1]
|
||||
last_message_str = str(last_message).replace('\\', '')
|
||||
@@ -71,15 +59,15 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status_tools:
|
||||
if loop_count < 2: # 最大循环次数 3
|
||||
if loop_count < 2: # Maximum loop count is 3
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# 添加默认返回值,避免返回 None
|
||||
# Add default return value to avoid returning None
|
||||
counter.reset()
|
||||
return "Summary" # 或根据业务需求选择合适的默认值
|
||||
return "Summary" # Default based on business requirements
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
@@ -115,8 +103,8 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
# 添加默认返回值,避免返回 None
|
||||
return 'Retrieve_Summary' # 或根据业务逻辑选择合适的默认值
|
||||
# Add default return value to avoid returning None
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
@@ -151,46 +139,7 @@ def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
# 在 input_sentence 函数中修改参数名称
|
||||
async def input_sentence(state, name, id, search_switch, apply_id, group_id):
    """Build the initial tool-call message from the latest message in ``state``.

    The content of the last message is used as the sentence. If it ends with
    ``.jpg`` or ``.png`` it is first passed to ``picture_model_requests``; if
    it ends with one of ``audio_extensions`` it is passed through
    ``Vico_recognition`` (both presumably return text — confirm against their
    definitions). If the text contains ``verified_data``, the first
    ``"query"`` value found in it replaces the sentence; when no query can be
    found the text is kept unchanged instead of raising ``IndexError``.

    Args:
        state: Workflow state with a ``messages`` list.
        name: Tool name to put in the tool call.
        id: Session identifier; also used as the prefix of the tool call id.
        search_switch: Value forwarded to the tool in ``args``.
        apply_id: Application identifier forwarded to the tool.
        group_id: Group identifier forwarded to the tool.

    Returns:
        A state update ``{"messages": [AIMessage]}`` whose single message
        carries one tool call.
    """
    messages = state["messages"]
    last_message = messages[-1].content if messages else ""

    if last_message.endswith(('.jpg', '.png')):
        last_message = await picture_model_requests(last_message)
    if any(last_message.endswith(ext) for ext in audio_extensions):
        last_message = await Vico_recognition([last_message]).run()
        logger.debug(f"Audio recognition result: {last_message}")

    uuid_str = uuid.uuid4()

    if 'verified_data' in str(last_message):
        # Strip escaped newlines and backslashes so the query regex can match.
        messages_last = str(last_message).replace('\\n', '').replace('\\', '')
        queries = re.findall(r'"query": "(.*?)",', str(messages_last))
        if queries:
            last_message = queries[0]
        else:
            # Keep the original text rather than failing on a missing query.
            logger.warning("input_sentence: 'verified_data' present but no query found; using message as-is")

    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[{
                    "name": name,
                    "args": {
                        "sentence": last_message,
                        'sessionid': id,
                        'messages_id': str(uuid_str),
                        "search_switch": search_switch,  # search_switch is passed inside args
                        "apply_id": apply_id,
                        "group_id": group_id
                    },
                    "id": id + f'_{uuid_str}'
                }]
            )
        ]
    }
|
||||
return 'Split_The_Problem' # Default case
|
||||
|
||||
|
||||
class ProblemExtensionNode:
|
||||
@@ -208,30 +157,28 @@ class ProblemExtensionNode:
|
||||
async def __call__(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1] if messages else ""
|
||||
logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}")
|
||||
if self.tool_name=='Input_Summary':
|
||||
tool_call =re.findall("'id': '(.*?)'",str(last_message))[0]
|
||||
else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
|
||||
# try:
|
||||
# content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message
|
||||
# except:
|
||||
# content = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
# 尝试从上一工具的结果中提取实际的内容载荷(而不是整个对象的字符串表示)
|
||||
logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}")
|
||||
if self.tool_name == 'Input_Summary':
|
||||
tool_call = re.findall("'id': '(.*?)'", str(last_message))[0]
|
||||
else:
|
||||
tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
|
||||
|
||||
# Try to extract actual content payload from previous tool result
|
||||
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
extracted_payload = None
|
||||
# 捕获 ToolMessage 的 content 字段(支持单/双引号),并避免贪婪匹配
|
||||
# Capture ToolMessage content field (supports single/double quotes), avoid greedy matching
|
||||
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
|
||||
if m:
|
||||
extracted_payload = m.group(1)
|
||||
else:
|
||||
# 回退:直接尝试使用原始字符串
|
||||
# Fallback: use raw string directly
|
||||
extracted_payload = raw_msg
|
||||
|
||||
# 优先尝试将内容解析为 JSON
|
||||
# Try to parse content as JSON first
|
||||
try:
|
||||
content = json.loads(extracted_payload)
|
||||
except Exception:
|
||||
# 尝试从文本中提取 JSON 片段再解析
|
||||
# Try to extract JSON fragment from text and parse
|
||||
parsed = None
|
||||
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
|
||||
for cand in candidates:
|
||||
@@ -240,14 +187,14 @@ class ProblemExtensionNode:
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
# 如果仍然失败,则以原始字符串作为内容
|
||||
# If still fails, use raw string as content
|
||||
content = parsed if parsed is not None else extracted_payload
|
||||
|
||||
# 根据工具名称构建正确的参数
|
||||
# Build correct parameters based on tool name
|
||||
tool_args = {}
|
||||
|
||||
if self.tool_name == "Verify":
|
||||
# Verify工具需要context和usermessages参数
|
||||
# Verify tool requires context and usermessages parameters
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
@@ -256,7 +203,7 @@ class ProblemExtensionNode:
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Retrieve":
|
||||
# Retrieve工具需要context和usermessages参数
|
||||
# Retrieve tool requires context and usermessages parameters
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
@@ -266,9 +213,9 @@ class ProblemExtensionNode:
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary":
|
||||
# Summary工具需要字符串类型的context参数
|
||||
# Summary tool requires string type context parameter
|
||||
if isinstance(content, dict):
|
||||
# 将字典转换为JSON字符串
|
||||
# Convert dict to JSON string
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
@@ -276,24 +223,24 @@ class ProblemExtensionNode:
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary_fails":
|
||||
# Summary工具需要字符串类型的context参数
|
||||
# Summary_fails tool requires string type context parameter
|
||||
if isinstance(content, dict):
|
||||
# 将字典转换为JSON字符串
|
||||
# Convert dict to JSON string
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name=='Input_Summary':
|
||||
tool_args["context"] =str(last_message)
|
||||
elif self.tool_name == 'Input_Summary':
|
||||
tool_args["context"] = str(last_message)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["search_switch"] = str(self.search_switch)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
tool_args["storage_type"] = getattr(self, 'storage_type', "")
|
||||
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
|
||||
elif self.tool_name=='Retrieve_Summary' :
|
||||
elif self.tool_name == 'Retrieve_Summary':
|
||||
# Retrieve_Summary expects dict directly, not JSON string
|
||||
# content might be a JSON string, try to parse it
|
||||
if isinstance(content, str):
|
||||
@@ -320,7 +267,7 @@ class ProblemExtensionNode:
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
else:
|
||||
# 其他工具使用context参数
|
||||
# Other tools use context parameter
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
@@ -349,12 +296,24 @@ class ProblemExtensionNode:
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config_id=None,storage_type=None,user_rag_memory_id=None):
|
||||
async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None):
|
||||
"""
|
||||
Create a read graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
namespace: Namespace identifier
|
||||
tools: MCP tools loaded from session
|
||||
search_switch: Search mode switch ("0", "1", or "2")
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type (optional)
|
||||
user_rag_memory_id: User RAG memory ID (optional)
|
||||
"""
|
||||
memory = InMemorySaver()
|
||||
tool=[i.name for i in tools ]
|
||||
tool = [i.name for i in tools]
|
||||
logger.info(f"Initializing read graph with tools: {tool}")
|
||||
if config_id:
|
||||
logger.info(f"使用配置 ID: {config_id}")
|
||||
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
|
||||
|
||||
# Extract tool functions
|
||||
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
|
||||
@@ -382,9 +341,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
Retrieve_node = ToolExecutionNode(
|
||||
tool=Retrieve_,
|
||||
node_id="Retrieve_id",
|
||||
@@ -394,9 +354,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
Verify_node = ToolExecutionNode(
|
||||
tool=Verify_,
|
||||
node_id="Verify_id",
|
||||
@@ -406,7 +367,8 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
Summary_node = ToolExecutionNode(
|
||||
@@ -418,9 +380,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
Summary_fails_node = ToolExecutionNode(
|
||||
tool=Summary_fails_,
|
||||
node_id="Summary_fails_id",
|
||||
@@ -430,9 +393,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
Retrieve_Summary_node = ToolExecutionNode(
|
||||
tool=Retrieve_Summary_,
|
||||
node_id="Retrieve_Summary_id",
|
||||
@@ -442,9 +406,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
Input_Summary_node = ToolExecutionNode(
|
||||
tool=Input_Summary_,
|
||||
node_id="Input_Summary_id",
|
||||
@@ -454,16 +419,16 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
async def content_input_node(state):
|
||||
state_search_switch = state.get("search_switch", search_switch)
|
||||
|
||||
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
|
||||
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
|
||||
|
||||
|
||||
return await create_input_message(
|
||||
state=state,
|
||||
tool_name=tool_name,
|
||||
@@ -471,7 +436,8 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
multimodal_processor=multimodal_processor
|
||||
multimodal_processor=multimodal_processor,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -501,8 +467,3 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
|
||||
|
||||
graph = workflow.compile(checkpointer=memory)
|
||||
yield graph
|
||||
|
||||
|
||||
# 添加到文件末尾或创建新的执行脚本
|
||||
# 在 memory_agent_service.py 文件中添加以下函数
|
||||
|
||||
|
||||
@@ -128,6 +128,15 @@ def extract_content_payload(message: Any) -> Any:
|
||||
# For ToolMessages (responses from tools), extract from content
|
||||
if hasattr(message, "content"):
|
||||
raw_content = message.content
|
||||
logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}")
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(raw_content, list):
|
||||
for block in raw_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
raw_content = block.get('text', '')
|
||||
logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}")
|
||||
break
|
||||
|
||||
# If content is empty and this is an AIMessage with tool_calls,
|
||||
# extract from args (this handles the initial tool call from content_input)
|
||||
@@ -140,13 +149,16 @@ def extract_content_payload(message: Any) -> Any:
|
||||
|
||||
# If content is already a dict or list, return it directly
|
||||
if isinstance(raw_content, (dict, list)):
|
||||
logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}")
|
||||
return raw_content
|
||||
|
||||
# Try to parse as JSON
|
||||
if isinstance(raw_content, str):
|
||||
# First, try direct JSON parsing
|
||||
try:
|
||||
return json.loads(raw_content)
|
||||
parsed = json.loads(raw_content)
|
||||
logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
@@ -156,9 +168,12 @@ def extract_content_payload(message: Any) -> Any:
|
||||
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
|
||||
for candidate in json_candidates:
|
||||
try:
|
||||
return json.loads(candidate)
|
||||
parsed = json.loads(candidate)
|
||||
logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
|
||||
# If all parsing attempts fail, return the raw content
|
||||
logger.info(f"extract_content_payload: returning raw content (parsing failed)")
|
||||
return raw_content
|
||||
|
||||
@@ -1,69 +1,71 @@
|
||||
import asyncio
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import add_messages, StateGraph
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
import warnings
|
||||
import sys
|
||||
from langchain_core.messages import AIMessage
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
import asyncio
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
@asynccontextmanager
|
||||
async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None):
|
||||
logger.info("加载 MCP 工具: %s", [t.name for t in tools])
|
||||
if config_id:
|
||||
logger.info(f"使用配置 ID: {config_id}")
|
||||
|
||||
data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None)
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig):
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
logger.info("Loading MCP tools: %s", [t.name for t in tools])
|
||||
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
|
||||
|
||||
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
|
||||
|
||||
if not data_type_tool or not data_write_tool:
|
||||
logger.error('不存在数据存储工具', exc_info=True)
|
||||
raise ValueError('不存在数据存储工具')
|
||||
# ToolNode
|
||||
write_node = ToolNode([data_write_tool])
|
||||
if not data_write_tool:
|
||||
logger.error("Data_write tool not found", exc_info=True)
|
||||
raise ValueError("Data_write tool not found")
|
||||
|
||||
write_node = ToolNode([data_write_tool])
|
||||
|
||||
async def call_model(state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
content = last_message[1] if isinstance(last_message, tuple) else last_message.content
|
||||
|
||||
result = await data_type_tool.ainvoke({
|
||||
"context": last_message[1] if isinstance(last_message, tuple) else last_message.content
|
||||
})
|
||||
result=json.loads( result)
|
||||
|
||||
# 调用 Data_write,传递 config_id
|
||||
# Call Data_write directly with memory_config
|
||||
write_params = {
|
||||
"content": result["context"],
|
||||
"content": content,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"user_id": user_id
|
||||
"user_id": user_id,
|
||||
"memory_config": memory_config,
|
||||
}
|
||||
|
||||
# 如果提供了 config_id,添加到参数中
|
||||
if config_id:
|
||||
write_params["config_id"] = config_id
|
||||
logger.debug(f"传递 config_id 到 Data_write: {config_id}")
|
||||
|
||||
logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}")
|
||||
|
||||
write_result = await data_write_tool.ainvoke(write_params)
|
||||
|
||||
if isinstance(write_result, dict):
|
||||
content = write_result.get("data", str(write_result))
|
||||
result_content = write_result.get("data", str(write_result))
|
||||
else:
|
||||
content = str(write_result)
|
||||
logger.info("写入内容: %s", content)
|
||||
return {"messages": [AIMessage(content=content)]}
|
||||
result_content = str(write_result)
|
||||
logger.info("Write content: %s", result_content)
|
||||
return {"messages": [AIMessage(content=result_content)]}
|
||||
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("content_input", call_model)
|
||||
|
||||
@@ -10,19 +10,19 @@ Package structure:
|
||||
- models: Pydantic response models
|
||||
- services: Business logic services
|
||||
"""
|
||||
from app.core.memory.agent.mcp_server.server import (
|
||||
mcp,
|
||||
initialize_context,
|
||||
main,
|
||||
get_context_resource
|
||||
)
|
||||
# from app.core.memory.agent.mcp_server.server import (
|
||||
# mcp,
|
||||
# initialize_context,
|
||||
# main,
|
||||
# get_context_resource
|
||||
# )
|
||||
|
||||
# Import tools to register them (but don't export them)
|
||||
from app.core.memory.agent.mcp_server import tools
|
||||
# # Import tools to register them (but don't export them)
|
||||
# from app.core.memory.agent.mcp_server import tools
|
||||
|
||||
__all__ = [
|
||||
'mcp',
|
||||
'initialize_context',
|
||||
'main',
|
||||
'get_context_resource',
|
||||
]
|
||||
# __all__ = [
|
||||
# 'mcp',
|
||||
# 'initialize_context',
|
||||
# 'main',
|
||||
# 'get_context_resource',
|
||||
# ]
|
||||
@@ -6,19 +6,15 @@ in the context for dependency injection into tool functions.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.services.search_service import SearchService
|
||||
from app.core.memory.agent.mcp_server.services.session_service import SessionService
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
|
||||
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -78,17 +74,11 @@ def initialize_context():
|
||||
logger.info("Registering session_store in context")
|
||||
mcp.session_store = store
|
||||
|
||||
# Register LLM client
|
||||
try:
|
||||
logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}")
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
mcp.llm_client = llm_client
|
||||
logger.info("llm_client registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register llm_client: {e}", exc_info=True)
|
||||
# 注册一个 None 值,避免工具调用时找不到资源
|
||||
mcp.llm_client = None
|
||||
logger.warning("llm_client set to None due to initialization failure")
|
||||
# Note: LLM client is NOT loaded at server startup
|
||||
# It should be loaded dynamically when needed, with config_id passed explicitly
|
||||
# to make_write_graph or make_read_graph functions
|
||||
logger.info("LLM client will be loaded dynamically with config_id when needed")
|
||||
mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id
|
||||
|
||||
# Register application settings (renamed to avoid conflict with FastMCP's settings)
|
||||
logger.info("Registering app_settings in context")
|
||||
@@ -124,26 +114,20 @@ def main():
|
||||
Initializes context and starts the server with SSE transport.
|
||||
"""
|
||||
try:
|
||||
# logger.info("Starting MCP server initialization")
|
||||
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
|
||||
logger.info("Starting MCP server initialization")
|
||||
# Initialize context resources
|
||||
initialize_context()
|
||||
|
||||
# Import and register tools
|
||||
# logger.info("Importing MCP tools")
|
||||
from app.core.memory.agent.mcp_server.tools import (
|
||||
# Import and register tools (imports trigger tool registration)
|
||||
from app.core.memory.agent.mcp_server.tools import ( # noqa: F401
|
||||
data_tools,
|
||||
problem_tools,
|
||||
retrieval_tools,
|
||||
verification_tools,
|
||||
summary_tools,
|
||||
data_tools
|
||||
verification_tools,
|
||||
)
|
||||
# logger.info("All MCP tools imported and registered")
|
||||
|
||||
# Log registered tools for debugging
|
||||
import asyncio
|
||||
tools_list = asyncio.run(mcp.list_tools())
|
||||
# logger.info(f"Registered {len(tools_list)} MCP tools: {[t.name for t in tools_list]}")
|
||||
# Tools are registered via imports above
|
||||
|
||||
# Get MCP port from environment (default: 8081)
|
||||
mcp_port = int(os.getenv("MCP_PORT", "8081"))
|
||||
|
||||
@@ -4,22 +4,22 @@ Parameter Builder for constructing tool call arguments.
|
||||
This service provides tool-specific parameter transformation logic
|
||||
to build correct arguments for each tool type.
|
||||
"""
|
||||
import json
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ParameterBuilder:
|
||||
"""Service for building tool call arguments based on tool type."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the parameter builder."""
|
||||
logger.info("ParameterBuilder initialized")
|
||||
|
||||
|
||||
def build_tool_args(
|
||||
self,
|
||||
tool_name: str,
|
||||
@@ -28,8 +28,9 @@ class ParameterBuilder:
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build tool arguments based on tool type.
|
||||
@@ -48,6 +49,7 @@ class ParameterBuilder:
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||
|
||||
@@ -58,18 +60,19 @@ class ParameterBuilder:
|
||||
base_args = {
|
||||
"usermessages": tool_call_id,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
"group_id": group_id,
|
||||
"memory_config": memory_config,
|
||||
}
|
||||
|
||||
|
||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||
base_args["storage_type"] = storage_type if storage_type is not None else ""
|
||||
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
|
||||
|
||||
# Tool-specific argument construction
|
||||
if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']:
|
||||
# Verify expects dict context
|
||||
if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]:
|
||||
# These tools expect dict context
|
||||
return {
|
||||
"context": content if isinstance(content, dict) else {},
|
||||
"context": content if isinstance(content, dict) else {"content": content},
|
||||
**base_args
|
||||
}
|
||||
|
||||
|
||||
@@ -4,21 +4,31 @@ Search Service for executing hybrid search and processing results.
|
||||
This service provides clean search result processing with content extraction
|
||||
and deduplication.
|
||||
"""
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the search service."""
|
||||
|
||||
def __init__(self, memory_config: "MemoryConfig" = None):
|
||||
"""
|
||||
Initialize the search service.
|
||||
|
||||
Args:
|
||||
memory_config: Optional MemoryConfig for embedding model configuration.
|
||||
If not provided, must be passed to execute_hybrid_search.
|
||||
"""
|
||||
self.memory_config = memory_config
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
def extract_content_from_result(self, result: dict) -> str:
|
||||
@@ -93,12 +103,13 @@ class SearchService:
|
||||
self,
|
||||
group_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
limit: int = 15,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False
|
||||
return_raw_results: bool = False,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
@@ -112,6 +123,7 @@ class SearchService:
|
||||
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
memory_config: MemoryConfig object for embedding model. Falls back to self.memory_config if not provided.
|
||||
|
||||
Returns:
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
@@ -119,12 +131,17 @@ class SearchService:
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
|
||||
# Use provided memory_config or fall back to instance config
|
||||
config = memory_config or self.memory_config
|
||||
if not config:
|
||||
raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search")
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
|
||||
try:
|
||||
# Execute search
|
||||
# Execute search using embedding_model_id from memory_config
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
@@ -132,7 +149,8 @@ class SearchService:
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
rerank_alpha=rerank_alpha
|
||||
embedding_id=str(config.embedding_model_id),
|
||||
rerank_alpha=rerank_alpha,
|
||||
)
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
|
||||
@@ -3,16 +3,19 @@ Data Tools for data type differentiation and writing.
|
||||
|
||||
This module contains MCP tools for distinguishing data types and writing data.
|
||||
"""
|
||||
import os
|
||||
|
||||
from mcp.server.fastmcp import Context
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.models.retrieval_models import (
|
||||
DistinguishTypeResponse,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -20,7 +23,8 @@ logger = get_agent_logger(__name__)
|
||||
@mcp.tool()
|
||||
async def Data_type_differentiation(
|
||||
ctx: Context,
|
||||
context: str
|
||||
context: str,
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Distinguish the type of data (read or write).
|
||||
@@ -28,6 +32,7 @@ async def Data_type_differentiation(
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Text to analyze for type differentiation
|
||||
memory_config: MemoryConfig object containing LLM configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' with the original text and 'type' field
|
||||
@@ -35,7 +40,9 @@ async def Data_type_differentiation(
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
|
||||
# Render template
|
||||
try:
|
||||
@@ -53,7 +60,7 @@ async def Data_type_differentiation(
|
||||
"type": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
@@ -98,7 +105,7 @@ async def Data_write(
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
config_id: str
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
@@ -109,7 +116,7 @@ async def Data_write(
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
config_id: Configuration ID for processing (optional, integer)
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status', 'saved_to', and 'data' fields
|
||||
@@ -118,32 +125,28 @@ async def Data_write(
|
||||
# Ensure output directory exists
|
||||
os.makedirs("data_output", exist_ok=True)
|
||||
file_path = os.path.join("data_output", "user_data.csv")
|
||||
|
||||
# Write data using utility function
|
||||
try:
|
||||
await write(content, user_id, apply_id, group_id, config_id=config_id)
|
||||
logger.info(f"写入成功!Config ID: {config_id if config_id else 'None'}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"saved_to": file_path,
|
||||
"data": content,
|
||||
"config_id": config_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"写入失败: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Data_write failed: {e}",
|
||||
exc_info=True
|
||||
|
||||
# Write data - clients are constructed inside write() from memory_config
|
||||
await write(
|
||||
content=content,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"saved_to": file_path,
|
||||
"data": content,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data_write failed: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
"message": str(e),
|
||||
}
|
||||
|
||||
@@ -2,25 +2,23 @@
|
||||
Problem Tools for question segmentation and extension.
|
||||
|
||||
This module contains MCP tools for breaking down and extending user questions.
|
||||
LLM clients are constructed from MemoryConfig when needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.problem_models import (
|
||||
ProblemBreakdownItem,
|
||||
ProblemBreakdownResponse,
|
||||
ExtendedQuestionItem,
|
||||
ProblemExtensionResponse
|
||||
ProblemExtensionResponse,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -32,7 +30,8 @@ async def Split_The_Problem(
|
||||
sessionid: str,
|
||||
messages_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Segment the dialogue or sentence into sub-problems.
|
||||
@@ -44,17 +43,20 @@ async def Split_The_Problem(
|
||||
messages_id: Message identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' (JSON string of split results) and 'original' sentence
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
|
||||
# Extract user ID from session
|
||||
user_id = session_service.resolve_user_id(sessionid)
|
||||
@@ -116,8 +118,8 @@ async def Split_The_Problem(
|
||||
)
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
logger.info("问题拆分")
|
||||
logger.info(f"问题拆分结果==>>:{split_result}")
|
||||
logger.info("Problem splitting")
|
||||
logger.info(f"Problem split result: {split_result}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
@@ -150,7 +152,7 @@ async def Split_The_Problem(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('问题拆分', duration)
|
||||
log_time('Problem splitting', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -160,8 +162,9 @@ async def Problem_Extension(
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Extend the problem with additional sub-questions.
|
||||
@@ -172,6 +175,7 @@ async def Problem_Extension(
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
@@ -179,12 +183,14 @@ async def Problem_Extension(
|
||||
dict: Contains 'context' (aggregated questions) and 'original' question
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID from usermessages
|
||||
from app.core.memory.agent.utils.messages_tool import Resolve_username
|
||||
@@ -250,8 +256,8 @@ async def Problem_Extension(
|
||||
)
|
||||
aggregated_dict = {}
|
||||
|
||||
logger.info("问题扩展")
|
||||
logger.info(f"问题扩展==>>:{aggregated_dict}")
|
||||
logger.info("Problem extension")
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
@@ -290,4 +296,4 @@ async def Problem_Extension(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('问题扩展', duration)
|
||||
log_time('Problem extension', duration)
|
||||
|
||||
@@ -3,25 +3,24 @@ Retrieval Tools for database and context retrieval.
|
||||
|
||||
This module contains MCP tools for retrieving data using hybrid search.
|
||||
"""
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import os
|
||||
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
|
||||
load_dotenv()
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@@ -32,8 +31,9 @@ async def Retrieve(
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Retrieve data from the database using hybrid search.
|
||||
@@ -44,6 +44,7 @@ async def Retrieve(
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
@@ -66,6 +67,7 @@ async def Retrieve(
|
||||
}
|
||||
start = time.time()
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}")
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
@@ -77,7 +79,13 @@ async def Retrieve(
|
||||
if isinstance(context, dict):
|
||||
# Process dict context with extended questions
|
||||
all_items = []
|
||||
logger.info(f"Retrieve: context keys={list(context.keys())}")
|
||||
content, original = await Retriev_messages_deal(context)
|
||||
logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}")
|
||||
logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'")
|
||||
|
||||
if not original:
|
||||
logger.warning(f"Retrieve: original query is empty! context={context}")
|
||||
|
||||
# Extract all query items from content
|
||||
# content is like {original_question: [extended_questions...], ...}
|
||||
@@ -113,9 +121,11 @@ async def Retrieve(
|
||||
clean_content = ''
|
||||
raw_results=''
|
||||
cleaned_query = question
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
|
||||
databases_anser.append({
|
||||
"Query_small": cleaned_query,
|
||||
@@ -206,9 +216,11 @@ async def Retrieve(
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = query
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
# Keep structure for Verify/Retrieve_Summary compatibility
|
||||
dup_databases = {
|
||||
"Query": cleaned_query,
|
||||
@@ -236,7 +248,7 @@ async def Retrieve(
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
|
||||
f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
|
||||
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
|
||||
)
|
||||
|
||||
@@ -279,4 +291,4 @@ async def Retrieve(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
log_time('Retrieval', duration)
|
||||
|
||||
@@ -2,33 +2,31 @@
|
||||
Summary Tools for data summarization.
|
||||
|
||||
This module contains MCP tools for summarizing retrieved data and generating responses.
|
||||
LLM clients are constructed from MemoryConfig when needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.mcp_server.models.summary_models import (
|
||||
SummaryData,
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
RetrieveSummaryData,
|
||||
RetrieveSummaryResponse
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Resolve_username,
|
||||
Summary_messages_deal,
|
||||
Resolve_username
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
# Load environment variables from the .env file
|
||||
load_dotenv()
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -40,8 +38,9 @@ async def Summary(
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize the verified data.
|
||||
@@ -52,6 +51,7 @@ async def Summary(
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
@@ -59,12 +59,14 @@ async def Summary(
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
@@ -155,7 +157,7 @@ async def Summary(
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"验证之后的总结==>>:{aimessages}")
|
||||
logger.info(f"Summary after verification: {aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
@@ -163,7 +165,7 @@ async def Summary(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('总结', duration)
|
||||
log_time('Summary', duration)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
@@ -180,8 +182,9 @@ async def Retrieve_Summary(
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize data directly from retrieval results.
|
||||
@@ -192,6 +195,7 @@ async def Retrieve_Summary(
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
@@ -202,9 +206,11 @@ async def Retrieve_Summary(
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
@@ -212,6 +218,8 @@ async def Retrieve_Summary(
|
||||
|
||||
|
||||
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
|
||||
logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}")
|
||||
|
||||
if isinstance(context, dict):
|
||||
if "content" in context:
|
||||
inner = context["content"]
|
||||
@@ -252,17 +260,19 @@ async def Retrieve_Summary(
|
||||
|
||||
query = context_dict.get("Query", "")
|
||||
expansion_issue = context_dict.get("Expansion_issue", [])
|
||||
|
||||
logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}")
|
||||
logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}")
|
||||
|
||||
# Extract retrieve_info from expansion_issue
|
||||
retrieve_info = []
|
||||
for item in expansion_issue:
|
||||
# Check for both Answer_Small and Answer_Samll (typo) for backward compatibility
|
||||
# Check for both Answer_Small and Answer_Small (typo) for backward compatibility
|
||||
answer = None
|
||||
if isinstance(item, dict):
|
||||
if "Answer_Small" in item:
|
||||
answer = item["Answer_Small"]
|
||||
elif "Answer_Samll" in item:
|
||||
answer = item["Answer_Samll"]
|
||||
|
||||
|
||||
if answer is not None:
|
||||
# Handle both string and list formats
|
||||
@@ -350,7 +360,7 @@ async def Retrieve_Summary(
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"检索之后的总结==>>:{aimessages}")
|
||||
logger.info(f"Summary after retrieval: {aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
@@ -358,7 +368,7 @@ async def Retrieve_Summary(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索总结', duration)
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
@@ -384,8 +394,9 @@ async def Input_Summary(
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a quick summary for direct input without verification.
|
||||
@@ -397,6 +408,7 @@ async def Input_Summary(
|
||||
search_switch: Search switch value for routing ('2' for summaries only)
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
@@ -406,21 +418,14 @@ async def Input_Summary(
|
||||
start = time.time()
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# Initialize variables to avoid UnboundLocalError
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
search_service = get_context_resource(ctx, 'search_service')
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
search_service = get_context_resource(ctx, "search_service")
|
||||
|
||||
# Check if llm_client is None
|
||||
if llm_client is None:
|
||||
error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
# Get LLM client from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
@@ -479,7 +484,7 @@ async def Input_Summary(
|
||||
|
||||
# Add storage-specific parameters
|
||||
|
||||
'''检索'''
|
||||
# Retrieval
|
||||
if search_switch == '2':
|
||||
search_params["include"] = ["summaries"]
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
@@ -509,12 +514,16 @@ async def Input_Summary(
|
||||
except:
|
||||
retrieve_info=''
|
||||
raw_results=['']
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
logger.info("Input_Summary: 使用 summary 进行检索")
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
logger.info("Input_Summary: Using summary for retrieval")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -547,7 +556,7 @@ async def Input_Summary(
|
||||
)
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
@@ -587,7 +596,7 @@ async def Input_Summary(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
log_time('Retrieval', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
||||
@@ -5,20 +5,19 @@ This module contains MCP tools for verifying retrieved data.
|
||||
"""
|
||||
import time
|
||||
|
||||
from jinja2 import Template
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.verify_tool import VerifyTool
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Verify_messages_deal,
|
||||
Retrieve_verify_tool_messages_deal,
|
||||
Resolve_username
|
||||
)
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Resolve_username,
|
||||
Retrieve_verify_tool_messages_deal,
|
||||
Verify_messages_deal,
|
||||
)
|
||||
from app.core.memory.agent.utils.verify_tool import VerifyTool
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from jinja2 import Template
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -30,6 +29,7 @@ async def Verify(
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
@@ -42,6 +42,7 @@ async def Verify(
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
@@ -91,8 +92,12 @@ async def Verify(
|
||||
|
||||
|
||||
|
||||
# Call verification workflow
|
||||
verify_tool = VerifyTool(system_prompt, messages)
|
||||
# Call verification workflow with LLM model ID from memory_config
|
||||
verify_tool = VerifyTool(
|
||||
system_prompt=system_prompt,
|
||||
verify_data=messages,
|
||||
llm_model_id=str(memory_config.llm_model_id)
|
||||
)
|
||||
verify_result = await verify_tool.verify()
|
||||
|
||||
# Parse LLM verification result with error handling
|
||||
@@ -118,7 +123,7 @@ async def Verify(
|
||||
"history": history,
|
||||
}
|
||||
|
||||
logger.info(f"验证==>>:{messages_deal}")
|
||||
logger.info(f"Verification result: {messages_deal}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
@@ -128,7 +133,7 @@ async def Verify(
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "数据验证",
|
||||
"title": "Data Verification",
|
||||
"result": messages_deal.get("split_result", "unknown"),
|
||||
"reason": messages_deal.get("reason", ""),
|
||||
"query": query,
|
||||
@@ -166,4 +171,4 @@ async def Verify(
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('验证', duration)
|
||||
log_time('Verification', duration)
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import TypedDict, Annotated
|
||||
import os
|
||||
import logging
|
||||
|
||||
from jinja2 import Template
|
||||
from langchain_core.messages import AnyMessage
|
||||
from dotenv import load_dotenv
|
||||
from langgraph.graph import add_messages
|
||||
from openai import OpenAI
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_picture_config,
|
||||
get_voice_config,
|
||||
)
|
||||
|
||||
# Removed global variable imports - use dependency injection instead
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
from openai import OpenAI
|
||||
|
||||
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,6 +43,7 @@ class WriteState(TypedDict):
|
||||
user_id:str
|
||||
apply_id:str
|
||||
group_id:str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
|
||||
class ReadState(TypedDict):
|
||||
'''
|
||||
@@ -53,6 +53,7 @@ class ReadState(TypedDict):
|
||||
loop_count:Traverse times
|
||||
search_switch:type
|
||||
config_id: configuration id for filtering results
|
||||
errors: list of errors that occurred during workflow execution
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息
|
||||
name: str
|
||||
@@ -63,6 +64,7 @@ class ReadState(TypedDict):
|
||||
apply_id: str
|
||||
group_id: str
|
||||
config_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
|
||||
|
||||
class COUNTState:
|
||||
@@ -109,9 +111,17 @@ def deduplicate_entries(entries):
|
||||
|
||||
|
||||
|
||||
async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
|
||||
async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION, picture_model_name: str) -> str:
|
||||
"""
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
PROMPT_TICKET_EXTRACTION: Extraction prompt
|
||||
picture_model_name: Picture model name (required, no longer from global variables)
|
||||
"""
|
||||
try:
|
||||
model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME)
|
||||
model_config = get_picture_config(picture_model_name)
|
||||
except Exception as e:
|
||||
err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
|
||||
logger.error(err)
|
||||
@@ -147,9 +157,15 @@ async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
|
||||
picture_text = json.loads(picture_text)
|
||||
return (picture_text['statement'])
|
||||
|
||||
async def Voice_recognize():
|
||||
async def Voice_recognize(voice_model_name: str):
|
||||
"""
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
voice_model_name: Voice model name (required, no longer from global variables)
|
||||
"""
|
||||
try:
|
||||
model_config = get_voice_config(SELECTED_LLM_VOICE_NAME)
|
||||
model_config = get_voice_config(voice_model_name)
|
||||
except Exception as e:
|
||||
err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
|
||||
logger.error(err)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Any
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -119,11 +119,23 @@ async def Problem_Extension_messages_deal(context):
|
||||
extent_quest = []
|
||||
original = context.get('original', '')
|
||||
messages = context.get('context', '')
|
||||
messages = json.loads(messages)
|
||||
for message in messages:
|
||||
question = message.get('question', '')
|
||||
type = message.get('type', '')
|
||||
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
|
||||
|
||||
# Handle empty or non-string messages
|
||||
if not messages:
|
||||
return extent_quest, original
|
||||
|
||||
if isinstance(messages, str):
|
||||
try:
|
||||
messages = json.loads(messages)
|
||||
except json.JSONDecodeError:
|
||||
# If JSON parsing fails, return empty list
|
||||
return extent_quest, original
|
||||
|
||||
if isinstance(messages, list):
|
||||
for message in messages:
|
||||
question = message.get('question', '')
|
||||
type = message.get('type', '')
|
||||
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
|
||||
|
||||
return extent_quest, original
|
||||
|
||||
@@ -135,10 +147,19 @@ async def Retriev_messages_deal(context):
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
logger.info(f"Retriev_messages_deal input: type={type(context)}, value={str(context)[:500]}")
|
||||
|
||||
if isinstance(context, dict):
|
||||
logger.info(f"Retriev_messages_deal: context is dict with keys={list(context.keys())}")
|
||||
if 'context' in context or 'original' in context:
|
||||
return context.get('context', {}), context.get('original', '')
|
||||
return content, original_value
|
||||
content = context.get('context', {})
|
||||
original = context.get('original', '')
|
||||
logger.info(f"Retriev_messages_deal output: content_type={type(content)}, content={str(content)[:300]}, original='{original[:50] if original else ''}'")
|
||||
return content, original
|
||||
|
||||
# Return empty defaults if context is not a dict or doesn't have expected keys
|
||||
logger.warning(f"Retriev_messages_deal: context missing expected keys, returning empty defaults")
|
||||
return {}, ''
|
||||
|
||||
async def Verify_messages_deal(context):
|
||||
'''
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
# 角色
|
||||
你是验证专家
|
||||
你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析,是不是回答Query_Samll这个字段的问题
|
||||
你的目标是针对用户的输入Query_Small字段的提问和Answer_Small的回答分析,是不是回答Query_Small这个字段的问题
|
||||
|
||||
{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#}
|
||||
## 工作步骤
|
||||
1. 获取所有的Query_Samll字段和Answer_Samll字段
|
||||
2. 分析Answer_Samll的回复是不是和Query_Samll有关系
|
||||
3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态
|
||||
1. 获取所有的Query_Small字段和Answer_Small字段
|
||||
2. 分析Answer_Small的回复是不是和Query_Small有关系
|
||||
3. 判断Answer_Small和Query_Small之间分析出来的关系状态
|
||||
4. 如果是True保留,否则不要相对应的问题和回答
|
||||
5. 输出,需要严格按照模版
|
||||
输入:{{history}}
|
||||
历史消息:{"history":{{sentence}}}
|
||||
### 第一步 获取用户的输入
|
||||
获取用户的输入提取对应的Query_Samll和Answer_Samll
|
||||
获取用户的输入提取对应的Query_Small和Answer_Small
|
||||
### 第二步 分析验证
|
||||
需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容,如果有关系不是答非所问
|
||||
需要分析Query_Small和Answer_Small之间的关系可以参考history字段的内容,如果有关系不是答非所问
|
||||
## 核心验证标准
|
||||
在评估子问题拆分时,必须严格遵循以下标准,且验证过程中完全不依赖于子问题的相关信息(Answer_Samll):
|
||||
在评估子问题拆分时,必须严格遵循以下标准,且验证过程中完全不依赖于子问题的相关信息(Answer_Small):
|
||||
1. 合理性标准(必须全部满足):
|
||||
- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。
|
||||
- 最小化:每个不同的子问题数量应尽可能少,通常不超过原问题关键要素数量的2倍(建议2-4个),避免冗余和不必要拆分。
|
||||
|
||||
@@ -19,12 +19,14 @@ class DistinguishTypeResponse(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
async def status_typle(messages: str) -> dict:
|
||||
async def status_typle(messages: str, llm_model_id: str) -> dict:
|
||||
"""
|
||||
Classify message type as read or write operation.
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
messages: User message to classify
|
||||
llm_model_id: LLM model ID to use (required, no longer from global variables)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'type' field with classification result
|
||||
@@ -42,8 +44,7 @@ async def status_typle(messages: str) -> dict:
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
llm_client = get_llm_client(llm_model_id)
|
||||
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
|
||||
@@ -11,7 +11,7 @@ from langchain_core.messages import HumanMessage
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
|
||||
# Removed global variable imports - use dependency injection instead
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
load_dotenv(find_dotenv())
|
||||
@@ -31,8 +31,17 @@ class State(TypedDict):
|
||||
|
||||
|
||||
class VerifyTool:
|
||||
def __init__(self, system_prompt: str="", verify_data: Any=None):
|
||||
def __init__(self, system_prompt: str="", verify_data: Any=None, llm_model_id: str=None):
|
||||
"""
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
system_prompt: System prompt for verification
|
||||
verify_data: Data to verify
|
||||
llm_model_id: LLM model ID (required, no longer from global variables)
|
||||
"""
|
||||
self.system_prompt = system_prompt
|
||||
self.llm_model_id = llm_model_id
|
||||
if isinstance(verify_data, str):
|
||||
self.verify_data = verify_data
|
||||
else:
|
||||
@@ -42,7 +51,9 @@ class VerifyTool:
|
||||
self.verify_data = str(verify_data)
|
||||
|
||||
async def model_1(self, state: State) -> State:
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
if not self.llm_model_id:
|
||||
raise ValueError("llm_model_id is required but not provided")
|
||||
llm_client = get_llm_client(self.llm_model_id)
|
||||
response_content = await llm_client.chat(
|
||||
messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])]
|
||||
)
|
||||
|
||||
@@ -1,80 +1,93 @@
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
"""
|
||||
Write Tools for Memory Knowledge Extraction Pipeline
|
||||
|
||||
This module provides the main write function for executing the knowledge extraction
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation_all,
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# 导入配置模块(而不是直接导入变量)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
Memory_summary_generation,
|
||||
)
|
||||
from app.core.memory.utils.embedder.embedder_utils import (
|
||||
get_embedder_client_from_config,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None:
|
||||
|
||||
async def write(
|
||||
content: str,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
ref_id: str = "wyl20251027",
|
||||
) -> None:
|
||||
"""
|
||||
执行完整的知识提取流水线(使用新的 ExtractionOrchestrator)
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
|
||||
Only MemoryConfig is needed - LLM and embedding clients are constructed
|
||||
internally from the config.
|
||||
|
||||
Args:
|
||||
content: 对话内容
|
||||
user_id: 用户ID
|
||||
apply_id: 应用ID
|
||||
group_id: 组ID
|
||||
ref_id: 参考ID,默认为 "wyl20251027"
|
||||
config_id: 配置ID,用于标记数据处理配置
|
||||
content: Dialogue content to process
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
"""
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
config_id = str(memory_config.config_id)
|
||||
|
||||
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
||||
logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}")
|
||||
logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}")
|
||||
logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}")
|
||||
logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}")
|
||||
logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}")
|
||||
logger.info(f"Config ID: {config_id if config_id else 'None'}")
|
||||
logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}")
|
||||
logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}")
|
||||
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
|
||||
logger.info(f"Workspace: {memory_config.workspace_name}")
|
||||
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
||||
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||
logger.info(f"Group ID: {group_id}")
|
||||
|
||||
# Construct clients from memory_config
|
||||
llm_client = get_llm_client_from_config(memory_config)
|
||||
embedder_client = get_embedder_client_from_config(memory_config)
|
||||
logger.info("LLM and embedding clients constructed")
|
||||
|
||||
# Initialize timing log
|
||||
log_file = "logs/time.log"
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")
|
||||
f.write(f"Config: {memory_config.config_name} (ID: {config_id})\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
|
||||
# 初始化客户端
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
# 获取 embedder 配置
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
# Initialize Neo4j connector
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
# Step 1: 加载和分块数据
|
||||
|
||||
# Step 1: Load and chunk data
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await get_chunked_dialogs(
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
chunker_strategy=chunker_strategy,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
@@ -83,21 +96,21 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
config_id=config_id,
|
||||
)
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# Step 2: 初始化并运行 ExtractionOrchestrator
|
||||
|
||||
# Step 2: Initialize and run ExtractionOrchestrator
|
||||
step_start = time.time()
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
pipeline_config = get_pipeline_config()
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
config=pipeline_config,
|
||||
embedding_id=embedding_model_id,
|
||||
)
|
||||
|
||||
# 运行完整的提取流水线
|
||||
# orchestrator.run returns a flat tuple of 7 values after deduplication
|
||||
|
||||
# Run the complete extraction pipeline
|
||||
(
|
||||
all_dialogue_nodes,
|
||||
all_chunk_nodes,
|
||||
@@ -107,14 +120,12 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
all_dedup_details,
|
||||
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
# Step 8: Save all data to Neo4j database using graph models
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
# 运行索引创建
|
||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||
try:
|
||||
await create_fulltext_indexes()
|
||||
@@ -141,18 +152,16 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
# Step 9: Generate Memory summaries and save to local vector DB and Neo4j
|
||||
# Step 4: Generate Memory summaries and save to Neo4j
|
||||
step_start = time.time()
|
||||
try:
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
|
||||
)
|
||||
|
||||
# Save memory summaries to Neo4j as nodes
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
# Link summaries to statements via chunks for summary→entity queries
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
try:
|
||||
@@ -162,24 +171,15 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
finally:
|
||||
log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
|
||||
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
# Log total pipeline time
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||
|
||||
# Add completion marker to log
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Timing details saved to: {log_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?"
|
||||
asyncio.run(write(content, ref_id="wyl20251027"))
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
from neo4j import GraphDatabase
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
from neo4j import GraphDatabase
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ------------------- 自包含路径解析 -------------------
|
||||
@@ -31,11 +32,16 @@ except NameError:
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
# 现在路径已经配置好,我们可以使用绝对导入
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
|
||||
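# TODO: Fix this - the environment-variable defaults below are a stopgap replacing the globals previously imported from definitions.py.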
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
# 定义用于LLM结构化输出的Pydantic模型
|
||||
class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
@@ -140,8 +146,8 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
"""
|
||||
# 默认从 runtime.json selections.group_id 读取
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
# 默认从环境变量读取
|
||||
group_id = group_id or DEFAULT_GROUP_ID
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = get_raw_tags_from_db(group_id, limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
@@ -150,8 +156,7 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
llm_client = get_llm_client(DEFAULT_LLM_ID)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, llm_client)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
@@ -165,8 +170,8 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
|
||||
if __name__ == "__main__":
|
||||
print("开始获取热门记忆标签...")
|
||||
try:
|
||||
# 直接使用 runtime.json 中的 group_id
|
||||
group_id_to_query = SELECTED_GROUP_ID
|
||||
# 直接使用环境变量中的 group_id
|
||||
group_id_to_query = DEFAULT_GROUP_ID
|
||||
# 使用 asyncio.run 来执行异步主函数
|
||||
top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query))
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ This script can be executed directly to generate a memory insight report for a t
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
|
||||
@@ -17,12 +17,16 @@ src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
|
||||
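# TODO: Fix this - the environment-variable defaults below are a stopgap replacing the globals previously imported from definitions.py.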
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
# 定义用于LLM结构化输出的Pydantic模型
|
||||
class TagClassification(BaseModel):
|
||||
@@ -55,8 +59,7 @@ class MemoryInsight:
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
self.llm_client = get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
"""关闭数据库连接。"""
|
||||
@@ -294,8 +297,8 @@ async def main():
|
||||
"""
|
||||
Initializes and runs the memory insight analysis for a test user.
|
||||
"""
|
||||
# 默认从 runtime.json selections.group_id 读取
|
||||
test_user_id = SELECTED_GROUP_ID
|
||||
# 默认从环境变量读取
|
||||
test_user_id = DEFAULT_GROUP_ID
|
||||
print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")
|
||||
|
||||
insight = None
|
||||
|
||||
@@ -6,10 +6,10 @@ Usage:
|
||||
python -m analytics.user_summary --user_id <group_id>
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
@@ -24,10 +24,15 @@ try:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
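# TODO: Fix this - the environment-variable defaults below are a stopgap replacing the globals previously imported from definitions.py.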
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -42,8 +47,7 @@ class UserSummary:
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.connector = Neo4jConnector()
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
self.llm = get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
await self.connector.close()
|
||||
@@ -107,8 +111,8 @@ class UserSummary:
|
||||
|
||||
|
||||
async def generate_user_summary(user_id: str | None = None) -> str:
|
||||
# 默认从 runtime.json selections.group_id 读取
|
||||
effective_group_id = user_id or SELECTED_GROUP_ID
|
||||
# 默认从环境变量读取
|
||||
effective_group_id = user_id or DEFAULT_GROUP_ID
|
||||
svc = UserSummary(effective_group_id)
|
||||
try:
|
||||
return await svc.generate()
|
||||
@@ -139,7 +143,7 @@ if __name__ == "__main__":
|
||||
with open(dashboard_path, "r", encoding="utf-8") as rf:
|
||||
existing = json.load(rf)
|
||||
existing["user_summary"] = {
|
||||
"group_id": SELECTED_GROUP_ID,
|
||||
"group_id": DEFAULT_GROUP_ID,
|
||||
"summary": summary
|
||||
}
|
||||
with open(dashboard_path, "w", encoding="utf-8") as wf:
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
{
|
||||
"llm_list": [
|
||||
{
|
||||
"llm_name": "qwen2.5-14b-instruct-awq",
|
||||
"api_base": "http://175.27.131.196:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen2.5-14b-instruct-awq",
|
||||
"api_base": "http://175.27.131.196:9090/v1",
|
||||
"api_key": "OPENAI_API_AGENT_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen2.5-14b",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen2.5-14b-instruct-awq",
|
||||
"api_base": "http://175.27.131.196:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen3-14b",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/deepseek-r1-0528-qwen3-8b",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen3-235b-a22b-instruct-2507",
|
||||
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"api_key": "DASHSCOPE_API_KEY"
|
||||
}
|
||||
,
|
||||
{
|
||||
"llm_name": "openai/qwen-plus",
|
||||
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"api_key": "DASHSCOPE_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||
},
|
||||
{
|
||||
"llm_name": "bedrock/global.anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
}
|
||||
],
|
||||
"embedding_list": [
|
||||
{
|
||||
"embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
"api_base": "http://119.45.239.97:11434/v1",
|
||||
"dimension": 768
|
||||
},
|
||||
{
|
||||
"embedding_name": "openai/bge-m3",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"dimension": 1024
|
||||
}
|
||||
],
|
||||
"neo4j": {
|
||||
"uri": "bolt://1.94.111.67:7687",
|
||||
"username": "neo4j"
|
||||
},
|
||||
"chunker_list": [
|
||||
{
|
||||
"chunker_strategy": "TokenChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"chunk_overlap": 56,
|
||||
"tokenizer_or_token_counter": "character"
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"min_characters_per_chunk": 50
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "SemanticChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 1024,
|
||||
"threshold": 0.8,
|
||||
"min_sentences": 2,
|
||||
"skip_window": 1,
|
||||
"min_characters_per_chunk": 100
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "LateChunker",
|
||||
"embedding_model": "all-MiniLM-L6-v2",
|
||||
"chunk_size": 2048,
|
||||
"min_characters_per_chunk": 24
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "NeuralChunker",
|
||||
"embedding_model": "mirth/chonky_modernbert_base_1",
|
||||
"min_characters_per_chunk": 24
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "LLMChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 1000,
|
||||
"min_characters_per_chunk": 100
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "HybridChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"threshold": 0.8,
|
||||
"min_characters_per_chunk": 100
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "SentenceChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 2048,
|
||||
"chunk_overlap": 128,
|
||||
"min_sentences_per_chunk": 1,
|
||||
"min_characters_per_sentence": 12,
|
||||
"delim": [".", "!", "?", "\n"],
|
||||
"include_delim": "prev",
|
||||
"tokenizer_or_token_counter": "character"
|
||||
}
|
||||
],
|
||||
"langfuse": {
|
||||
"enabled": true
|
||||
},
|
||||
"agenta": {
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -1,5 +0,0 @@
|
||||
{
|
||||
"selections": {
|
||||
"config_id": ""
|
||||
}
|
||||
}
|
||||
@@ -51,16 +51,31 @@ logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
async def main(
|
||||
# Required configuration parameters (no longer from global variables)
|
||||
chunker_strategy: str,
|
||||
group_id: str,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
llm_model_id: str,
|
||||
embedding_model_id: str,
|
||||
# Optional parameters
|
||||
dialogue_text: Optional[str] = None,
|
||||
is_pilot_run: bool = False,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
|
||||
):
|
||||
"""
|
||||
记忆系统主流程 - 重构版本
|
||||
记忆系统主流程 - 重构版本 (Updated to eliminate global variables)
|
||||
|
||||
该函数是重构后的主入口,使用新的模块化架构。
|
||||
Global variables have been eliminated in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
chunker_strategy: Chunking strategy to use (required)
|
||||
group_id: Group ID for the operation (required)
|
||||
user_id: User ID for the operation (required)
|
||||
apply_id: Application ID for the operation (required)
|
||||
llm_model_id: LLM model ID to use (required)
|
||||
embedding_model_id: Embedding model ID to use (required)
|
||||
dialogue_text: 输入的对话文本(可选,用于试运行模式)
|
||||
is_pilot_run: 是否为试运行模式
|
||||
- True: 试运行模式,不保存到 Neo4j
|
||||
@@ -82,12 +97,10 @@ async def main(
|
||||
print("MemSci 知识提取流水线 - 重构版本")
|
||||
print("=" * 60)
|
||||
print(f"运行模式: {'试运行(不保存到Neo4j)' if is_pilot_run else '正常运行(保存到Neo4j)'}")
|
||||
print("Using chunker strategy:", config_defs.SELECTED_CHUNKER_STRATEGY)
|
||||
print("Using group ID:", config_defs.SELECTED_GROUP_ID)
|
||||
print("Using model ID:", config_defs.SELECTED_LLM_ID)
|
||||
print("Using embedding model ID:", config_defs.SELECTED_EMBEDDING_ID)
|
||||
print("LANGFUSE_ENABLED:", config_defs.LANGFUSE_ENABLED)
|
||||
print("AGENTA_ENABLED:", config_defs.AGENTA_ENABLED)
|
||||
print("Using chunker strategy:", chunker_strategy)
|
||||
print("Using group ID:", group_id)
|
||||
print("Using model ID:", llm_model_id)
|
||||
print("Using embedding model ID:", embedding_model_id)
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化日志
|
||||
@@ -104,11 +117,11 @@ async def main(
|
||||
logger.info("Initializing clients...")
|
||||
step_start = time.time()
|
||||
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
llm_client = get_llm_client(llm_model_id)
|
||||
|
||||
# 获取 embedder 配置并转换为 RedBearModelConfig 对象
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
embedder_config_dict = get_embedder_config(embedding_model_id)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
@@ -145,9 +158,9 @@ async def main(
|
||||
dialog = DialogData(
|
||||
context=context,
|
||||
ref_id="pilot_dialog_1",
|
||||
group_id=config_defs.SELECTED_GROUP_ID,
|
||||
user_id=config_defs.SELECTED_USER_ID,
|
||||
apply_id=config_defs.SELECTED_APPLY_ID,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"}
|
||||
)
|
||||
|
||||
@@ -158,7 +171,7 @@ async def main(
|
||||
# 对前端传入的对话进行分块处理
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
chunker_strategy=chunker_strategy,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
|
||||
@@ -172,7 +185,7 @@ async def main(
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dialog.id,
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
"chunker_strategy": chunker_strategy
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
@@ -180,7 +193,7 @@ async def main(
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
"chunker_strategy": chunker_strategy
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
else:
|
||||
@@ -199,11 +212,11 @@ async def main(
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
group_id=config_defs.SELECTED_GROUP_ID,
|
||||
user_id=config_defs.SELECTED_USER_ID,
|
||||
apply_id=config_defs.SELECTED_APPLY_ID,
|
||||
indices=config_defs.SELECTED_TEST_DATA_INDICES,
|
||||
chunker_strategy=chunker_strategy,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
indices=None,
|
||||
input_data_path=test_data_path,
|
||||
llm_client=llm_client,
|
||||
skip_cleaning=True,
|
||||
@@ -219,7 +232,7 @@ async def main(
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dialog.id,
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
"chunker_strategy": chunker_strategy
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
@@ -227,7 +240,7 @@ async def main(
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
"chunker_strategy": chunker_strategy
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
@@ -249,6 +262,7 @@ async def main(
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback, # 传递进度回调
|
||||
embedding_id=embedding_model_id, # 传递嵌入模型ID
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
@@ -352,7 +366,7 @@ async def main(
|
||||
)
|
||||
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
|
||||
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
|
||||
)
|
||||
|
||||
if not is_pilot_run:
|
||||
@@ -400,4 +414,17 @@ async def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
print("⚠️ Warning: This script now requires explicit configuration parameters.")
|
||||
print("Global variables have been removed. Please provide configuration parameters.")
|
||||
print("Example usage:")
|
||||
print(" asyncio.run(main(")
|
||||
print(" chunker_strategy='RecursiveChunker',")
|
||||
print(" group_id='your_group_id',")
|
||||
print(" user_id='your_user_id',")
|
||||
print(" apply_id='your_apply_id',")
|
||||
print(" llm_model_id='your_llm_id',")
|
||||
print(" embedding_model_id='your_embedding_id'")
|
||||
print(" ))")
|
||||
|
||||
# This will fail because global variables are removed
|
||||
raise RuntimeError("Global variables removed. Please provide explicit configuration parameters.")
|
||||
|
||||
@@ -1,31 +1,41 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dotenv import load_dotenv
|
||||
from datetime import datetime
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_graph_by_embedding, search_graph,
|
||||
search_graph_by_temporal, search_graph_by_keyword_temporal,
|
||||
search_graph_by_chunk_id
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.config_models import TemporalSearchParams
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config, get_pipeline_config
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
from app.core.memory.utils.data.text_utils import extract_plain_query
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import (
|
||||
ForgettingEngine,
|
||||
)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_embedder_config,
|
||||
get_pipeline_config,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
|
||||
from app.core.memory.utils.data.text_utils import extract_plain_query
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_graph,
|
||||
search_graph_by_chunk_id,
|
||||
search_graph_by_embedding,
|
||||
search_graph_by_keyword_temporal,
|
||||
search_graph_by_temporal,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
@@ -131,7 +141,7 @@ def rerank_hybrid_results(
|
||||
|
||||
# Add keyword results with BM25 scores
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
@@ -139,7 +149,7 @@ def rerank_hybrid_results(
|
||||
|
||||
# Add or update with embedding results
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id:
|
||||
if item_id in combined_items:
|
||||
# Update existing item with embedding score
|
||||
@@ -220,7 +230,7 @@ def rerank_with_forgetting_curve(
|
||||
(keyword_items, False), (embedding_items, True)
|
||||
):
|
||||
for item in src_items:
|
||||
item_id = item.get("id") or item.get("uuid")
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if not item_id:
|
||||
continue
|
||||
existing = combined_items.get(item_id)
|
||||
@@ -266,26 +276,25 @@ def rerank_with_forgetting_curve(
|
||||
return reranked
|
||||
|
||||
|
||||
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = "search_log.txt"):
|
||||
"""Log search query information to file"""
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
"""Log search query information using the logger.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
search_type: Type of search (keyword, embedding, hybrid)
|
||||
group_id: Group identifier for filtering
|
||||
limit: Maximum number of results
|
||||
include: List of result types to include
|
||||
log_file: Deprecated parameter, kept for backward compatibility
|
||||
"""
|
||||
# Ensure the query text is plain and clean before logging
|
||||
cleaned_query = extract_plain_query(query_text)
|
||||
log_entry = {
|
||||
"timestamp": timestamp,
|
||||
# "query": query_text,
|
||||
"query": cleaned_query,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id,
|
||||
"limit": limit,
|
||||
"include": include
|
||||
}
|
||||
|
||||
# Append to log file
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
|
||||
|
||||
logger.info(f"Search logged: {query_text} ({search_type})")
|
||||
|
||||
# Log using the standard logger
|
||||
logger.info(
|
||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||
f"group_id={group_id}, limit={limit}, include={include}"
|
||||
)
|
||||
|
||||
|
||||
def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
|
||||
@@ -547,6 +556,7 @@ async def run_hybrid_search(
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
embedding_id: str,
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
@@ -558,6 +568,7 @@ async def run_hybrid_search(
|
||||
# Start overall timing
|
||||
search_start_time = time.time()
|
||||
latency_metrics = {}
|
||||
logger.info(f"using embedding_id:{embedding_id}...")
|
||||
|
||||
# Clean and normalize the incoming query before use/logging
|
||||
query_text = extract_plain_query(query_text)
|
||||
@@ -610,7 +621,7 @@ async def run_hybrid_search(
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
embedder_config_dict = get_embedder_config(embedding_id)
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
@@ -759,18 +770,11 @@ async def run_hybrid_search(
|
||||
else:
|
||||
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
|
||||
|
||||
completion_log = {
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"query": query_text,
|
||||
"search_type": search_type,
|
||||
"status": "completed",
|
||||
"result_counts": result_counts,
|
||||
"output_file": output_path,
|
||||
"latency_metrics": latency_metrics
|
||||
}
|
||||
|
||||
with open("search_log.txt", "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(completion_log, ensure_ascii=False) + "\n")
|
||||
# Log completion using the standard logger
|
||||
logger.info(
|
||||
f"Search completed: query='{query_text}', type={search_type}, "
|
||||
f"result_counts={result_counts}, latency={latency_metrics}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -969,6 +973,7 @@ def main():
|
||||
limit=args.limit,
|
||||
include=args.include,
|
||||
output_path=args.output,
|
||||
embedding_id=config_defs.SELECTED_EMBEDDING_ID,
|
||||
rerank_alpha=args.rerank_alpha,
|
||||
use_forgetting_rerank=args.forgetting_rerank,
|
||||
use_llm_rerank=args.llm_rerank,
|
||||
|
||||
@@ -19,50 +19,50 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable, Awaitable
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.graph_models import (
|
||||
DialogueNode,
|
||||
ChunkNode,
|
||||
StatementNode,
|
||||
DialogueNode,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# 导入各个提取模块
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
|
||||
StatementExtractor,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.triplet_extraction import (
|
||||
TripletExtractor,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.temporal_extraction import (
|
||||
TemporalExtractor,
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||
dedup_layers_and_merge_and_return,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation,
|
||||
embedding_generation_all,
|
||||
generate_entity_embeddings_from_triplets,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||
dedup_layers_and_merge_and_return,
|
||||
|
||||
# 导入各个提取模块
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
|
||||
StatementExtractor,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.temporal_extraction import (
|
||||
TemporalExtractor,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.triplet_extraction import (
|
||||
TripletExtractor,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
|
||||
_write_extracted_result_summary,
|
||||
export_test_input_doc,
|
||||
)
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -96,6 +96,7 @@ class ExtractionOrchestrator:
|
||||
connector: Neo4jConnector,
|
||||
config: Optional[ExtractionPipelineConfig] = None,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||
embedding_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
初始化流水线编排器
|
||||
@@ -108,6 +109,7 @@ class ExtractionOrchestrator:
|
||||
progress_callback: 进度回调函数
|
||||
- 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None]
|
||||
- 在管线关键点调用以报告进度和结果数据
|
||||
embedding_id: 嵌入模型ID,如果为 None 则从全局配置获取(向后兼容)
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.embedder_client = embedder_client
|
||||
@@ -115,6 +117,7 @@ class ExtractionOrchestrator:
|
||||
self.config = config or ExtractionPipelineConfig()
|
||||
self.is_pilot_run = False # 默认非试运行模式
|
||||
self.progress_callback = progress_callback # 保存进度回调函数
|
||||
self.embedding_id = embedding_id # 保存嵌入模型ID
|
||||
|
||||
# 保存去重消歧的详细记录(内存中的数据结构)
|
||||
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
|
||||
@@ -420,7 +423,9 @@ class ExtractionOrchestrator:
|
||||
return await self.triplet_extractor._extract_triplets(statement, chunk_content)
|
||||
except Exception as e:
|
||||
logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}")
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
|
||||
@@ -434,7 +439,9 @@ class ExtractionOrchestrator:
|
||||
d_idx, stmt_id = statement_metadata[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"陈述句处理异常: {result}")
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
triplet_maps[d_idx][stmt_id] = TripletExtractionResponse(triplets=[], entities=[])
|
||||
else:
|
||||
triplet_maps[d_idx][stmt_id] = result
|
||||
@@ -521,8 +528,8 @@ class ExtractionOrchestrator:
|
||||
temporal_maps[d_idx][stmt_id] = result
|
||||
|
||||
# 为 ATEMPORAL 陈述句添加空的时间范围
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.core.memory.models.message_models import TemporalValidityRange
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
for d_idx, dialog in enumerate(dialog_data_list):
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
@@ -629,17 +636,14 @@ class ExtractionOrchestrator:
|
||||
logger.info("开始生成基础嵌入向量(陈述句、分块、对话)")
|
||||
|
||||
try:
|
||||
# 从 runtime.json 获取嵌入模型配置ID
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
embedding_id = config_defs.SELECTED_EMBEDDING_ID
|
||||
|
||||
if not embedding_id:
|
||||
logger.error("未在 runtime.json 中配置 embedding 模型 ID")
|
||||
raise ValueError("未配置嵌入模型ID")
|
||||
# embedding_id is required - no fallback to global variable
|
||||
if not self.embedding_id:
|
||||
logger.error("embedding_id is required but was not provided to ExtractionOrchestrator")
|
||||
raise ValueError("embedding_id is required but was not provided")
|
||||
|
||||
# 只生成陈述句、分块和对话的嵌入(不包括实体)
|
||||
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = await embedding_generation(
|
||||
dialog_data_list, embedding_id
|
||||
dialog_data_list, self.embedding_id
|
||||
)
|
||||
|
||||
# 统计生成结果
|
||||
@@ -683,17 +687,14 @@ class ExtractionOrchestrator:
|
||||
logger.info("开始生成实体嵌入向量")
|
||||
|
||||
try:
|
||||
# 从 runtime.json 获取嵌入模型配置ID
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
embedding_id = config_defs.SELECTED_EMBEDDING_ID
|
||||
|
||||
if not embedding_id:
|
||||
logger.error("未在 runtime.json 中配置 embedding 模型 ID")
|
||||
# embedding_id is required - no fallback to global variable
|
||||
if not self.embedding_id:
|
||||
logger.error("embedding_id is required but was not provided to ExtractionOrchestrator")
|
||||
return triplet_maps
|
||||
|
||||
# 生成实体嵌入
|
||||
updated_triplet_maps = await generate_entity_embeddings_from_triplets(
|
||||
triplet_maps, embedding_id
|
||||
triplet_maps, self.embedding_id
|
||||
)
|
||||
|
||||
logger.info("实体嵌入生成完成")
|
||||
@@ -1086,7 +1087,9 @@ class ExtractionOrchestrator:
|
||||
if self.is_pilot_run:
|
||||
logger.info("试运行模式:仅执行第一层去重,跳过第二层数据库去重")
|
||||
# 只执行第一层去重
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
deduplicate_entities_and_edges,
|
||||
)
|
||||
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
|
||||
entity_nodes,
|
||||
@@ -1608,8 +1611,8 @@ async def get_chunked_dialogs(
|
||||
包含分块的 DialogData 对象列表
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import re
|
||||
|
||||
# 加载测试数据
|
||||
testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json")
|
||||
@@ -1671,7 +1674,9 @@ async def get_chunked_dialogs(
|
||||
)
|
||||
|
||||
# 创建分块器并处理对话
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
@@ -1718,7 +1723,9 @@ def preprocess_data(
|
||||
经过清洗转换后的 DialogData 列表
|
||||
"""
|
||||
print("\n=== 数据预处理 ===")
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
|
||||
DataPreprocessor,
|
||||
)
|
||||
preprocessor = DataPreprocessor()
|
||||
try:
|
||||
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
|
||||
@@ -1749,7 +1756,9 @@ async def get_chunked_dialogs_from_preprocessed(
|
||||
raise ValueError("预处理数据为空,无法进行分块")
|
||||
|
||||
all_chunked_dialogs: List[DialogData] = []
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
|
||||
for dialog_data in data:
|
||||
chunker = DialogueChunker(chunker_strategy, llm_client=llm_client)
|
||||
@@ -1811,7 +1820,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
|
||||
# 步骤2: 语义剪枝
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
|
||||
SemanticPruner,
|
||||
)
|
||||
pruner = SemanticPruner(llm_client=llm_client)
|
||||
|
||||
# 记录单对话场景下剪枝前的消息数量
|
||||
@@ -1834,7 +1845,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
|
||||
# 保存剪枝后的数据
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
|
||||
DataPreprocessor,
|
||||
)
|
||||
pruned_output_path = settings.get_memory_output_path("pruned_data.json")
|
||||
dp = DataPreprocessor(output_file_path=pruned_output_path)
|
||||
dp.save_data(preprocessed_data, output_path=pruned_output_path)
|
||||
|
||||
@@ -1,447 +0,0 @@
|
||||
|
||||
# TODO hybrid_chatbot.py 是一个独立的GUI演示应用,不是核心功能的一部分,可以考虑删除
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
import tkinter as tk
|
||||
from tkinter import scrolledtext, messagebox
|
||||
import threading
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
# Import our hybrid search functionality
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.config_models import LLMConfig
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class HybridSearchChatbot:
|
||||
def __init__(self):
    """Set up the LLM client, chat history, default search settings and GUI."""
    # The config definitions module is imported at call time, inside the method.
    from app.core.memory.utils.config import definitions as config_defs

    selected_llm_id = config_defs.SELECTED_LLM_ID
    self.llm_client = get_llm_client(selected_llm_id)

    # Conversation log; starts empty.
    self.chat_history = []

    # Defaults handed to run_hybrid_search; limit, include and rerank_alpha
    # can be changed from the configuration dialog.
    default_targets = ["statements", "chunks", "entities", "summaries"]
    self.search_config = dict(
        group_id="group_wyl_25",
        limit=10,
        include=default_targets,
        rerank_alpha=0.6,
    )

    # Build the window last so the attributes above already exist.
    self.setup_gui()
|
||||
|
||||
def setup_gui(self):
    """Build the main window: chat log, input row, status row and greeting."""
    large_font = ("Arial", 12)
    small_font = ("Arial", 10)

    # Top-level window.
    self.root = tk.Tk()
    self.root.title("Hybrid Search Chatbot")
    self.root.geometry("800x600")

    # Chat log; created disabled, add_message enables it while writing.
    self.chat_display = scrolledtext.ScrolledText(
        self.root, wrap=tk.WORD, width=80, height=25, state=tk.DISABLED
    )
    self.chat_display.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)

    # Input row: entry on the left, send button on the right.
    entry_row = tk.Frame(self.root)
    entry_row.pack(padx=10, pady=5, fill=tk.X)

    self.user_input = tk.Entry(entry_row, font=large_font)
    self.user_input.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 5))
    # Pressing Return sends the message, same as clicking the button.
    self.user_input.bind("<Return>", self.on_send_message)

    self.send_button = tk.Button(
        entry_row, text="发送", command=self.on_send_message, font=large_font
    )
    self.send_button.pack(side=tk.RIGHT)

    # Status row: status text on the left, settings button on the right.
    status_row = tk.Frame(self.root)
    status_row.pack(padx=10, pady=5, fill=tk.X)

    self.status_label = tk.Label(
        status_row, text="就绪", font=small_font, anchor="w"
    )
    self.status_label.pack(side=tk.LEFT, fill=tk.X, expand=True)

    settings_button = tk.Button(
        status_row,
        text="搜索配置",
        command=self.show_config_dialog,
        font=small_font,
    )
    settings_button.pack(side=tk.RIGHT)

    # Greeting shown as the first entry of the chat log.
    self.add_message("系统", "欢迎使用混合搜索聊天机器人!我可以基于知识图谱中的信息回答您的问题。")
|
||||
|
||||
def add_message(self, sender: str, message: str, metadata: Dict = None):
|
||||
"""Add a message to the chat display"""
|
||||
self.chat_display.config(state=tk.NORMAL)
|
||||
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
# Add sender and timestamp
|
||||
self.chat_display.insert(tk.END, f"[{timestamp}] {sender}:\n", "sender")
|
||||
|
||||
# Add message content
|
||||
self.chat_display.insert(tk.END, f"{message}\n", "message")
|
||||
|
||||
# Add metadata if available
|
||||
if metadata:
|
||||
self.chat_display.insert(tk.END, f" {metadata}\n", "metadata")
|
||||
|
||||
self.chat_display.insert(tk.END, "\n")
|
||||
self.chat_display.config(state=tk.DISABLED)
|
||||
self.chat_display.see(tk.END)
|
||||
|
||||
# Configure text tags for styling
|
||||
self.chat_display.tag_config("sender", foreground="blue", font=("Arial", 10, "bold"))
|
||||
self.chat_display.tag_config("message", foreground="black", font=("Arial", 10))
|
||||
self.chat_display.tag_config("metadata", foreground="gray", font=("Arial", 8))
|
||||
|
||||
def show_config_dialog(self):
    """Open a child window for editing rerank weight, limit and search targets."""
    settings = self.search_config

    # Child window tied to the main window; grab_set routes input to it.
    dialog = tk.Toplevel(self.root)
    dialog.title("搜索配置")
    dialog.geometry("400x600")
    dialog.transient(self.root)
    dialog.grab_set()

    # Read-only summary of the values currently in effect.
    summary_frame = tk.Frame(dialog)
    summary_frame.pack(pady=10, padx=10, fill=tk.X)
    tk.Label(summary_frame, text="当前配置:", font=("Arial", 10, "bold")).pack(anchor="w")
    joined_targets = ", ".join(settings["include"])
    summary_text = f"Alpha: {settings['rerank_alpha']}, 限制: {settings['limit']}, 目标: {joined_targets}"
    tk.Label(summary_frame, text=summary_text, font=("Arial", 9), fg="blue").pack(anchor="w")

    # Rerank weight slider (0.0 .. 1.0 in steps of 0.1).
    tk.Label(dialog, text="重排权重 (Alpha):").pack(pady=(10, 5))
    alpha_var = tk.DoubleVar(value=settings["rerank_alpha"])
    slider = tk.Scale(
        dialog,
        from_=0.0,
        to=1.0,
        resolution=0.1,
        orient=tk.HORIZONTAL,
        variable=alpha_var,
    )
    slider.pack(pady=5, padx=20, fill=tk.X)
    tk.Label(dialog, text="0.0=纯语义搜索, 1.0=纯关键词搜索", font=("Arial", 8)).pack()

    # Result limit spinbox (1 .. 50).
    tk.Label(dialog, text="搜索结果数量:").pack(pady=(20, 5))
    limit_var = tk.IntVar(value=settings["limit"])
    spinner = tk.Spinbox(dialog, from_=1, to=50, textvariable=limit_var, width=10)
    spinner.pack(pady=5)

    # One checkbox per searchable category.
    tk.Label(dialog, text="搜索目标:").pack(pady=(20, 5))
    targets_frame = tk.Frame(dialog)
    targets_frame.pack(pady=5)

    target_vars = {}
    for target in ["statements", "chunks", "entities", "summaries"]:
        checked = tk.BooleanVar(value=target in settings["include"])
        target_vars[target] = checked
        checkbox = tk.Checkbutton(targets_frame, text=target, variable=checked)
        checkbox.pack(side=tk.LEFT, padx=10)

    buttons_row = tk.Frame(dialog)
    buttons_row.pack(pady=20)

    def save_config():
        """Validate the dialog values and store them in self.search_config."""
        try:
            new_alpha = alpha_var.get()
            new_limit = limit_var.get()
            chosen = [name for name, flag in target_vars.items() if flag.get()]

            # At least one category must stay selected.
            if not chosen:
                messagebox.showerror("配置错误", "请至少选择一个搜索目标!")
                return

            self.search_config["rerank_alpha"] = new_alpha
            self.search_config["limit"] = new_limit
            self.search_config["include"] = chosen

            dialog.destroy()
            self.add_message(
                "系统",
                f"配置已更新: Alpha={new_alpha:.1f}, 限制={new_limit}, 目标={', '.join(chosen)}",
            )
        except Exception as e:
            messagebox.showerror("配置错误", f"保存配置时出错: {str(e)}")
            print(f"Config save error: {e}")  # Debug output

    tk.Button(buttons_row, text="保存", command=save_config).pack(side=tk.LEFT, padx=5)
    tk.Button(buttons_row, text="取消", command=dialog.destroy).pack(side=tk.LEFT, padx=5)
|
||||
|
||||
def on_send_message(self, event=None):
|
||||
"""Handle sending a message"""
|
||||
user_message = self.user_input.get().strip()
|
||||
if not user_message:
|
||||
return
|
||||
|
||||
# Clear input
|
||||
self.user_input.delete(0, tk.END)
|
||||
|
||||
# Add user message to display
|
||||
self.add_message("用户", user_message)
|
||||
|
||||
# Disable send button and show processing status
|
||||
self.send_button.config(state=tk.DISABLED)
|
||||
self.status_label.config(text="正在搜索和生成回复...")
|
||||
|
||||
# Process message in background thread
|
||||
threading.Thread(
|
||||
target=self.process_message_async,
|
||||
args=(user_message,),
|
||||
daemon=True
|
||||
).start()
|
||||
|
||||
def process_message_async(self, user_message: str):
|
||||
"""Process message asynchronously"""
|
||||
try:
|
||||
# Run the async processing
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
response, metadata = loop.run_until_complete(
|
||||
self.process_message(user_message)
|
||||
)
|
||||
loop.close()
|
||||
|
||||
# Update GUI in main thread
|
||||
self.root.after(0, self.on_response_ready, response, metadata)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"处理消息时出错: {str(e)}"
|
||||
self.root.after(0, self.on_error, error_msg)
|
||||
|
||||
async def process_message(self, user_message: str) -> Tuple[str, Dict[str, Any]]:
    """Answer one user message: hybrid search, then LLM generation.

    Args:
        user_message: The text the user submitted.

    Returns:
        A ``(response, metadata)`` pair. ``metadata`` holds the formatted
        search, generation and total timings, a summary of the search
        results and the rerank weight that was used.
    """
    settings = self.search_config
    overall_started = time.time()

    # Stage 1: retrieve candidates with the hybrid search.
    search_started = time.time()
    hits = await run_hybrid_search(
        query_text=user_message,
        search_type="hybrid",
        group_id=settings["group_id"],
        limit=settings["limit"],
        include=settings["include"],
        output_path=None,
        rerank_alpha=settings["rerank_alpha"],
    )
    search_elapsed = time.time() - search_started

    # Stage 2: flatten the hits into a text context for the prompt.
    context_text = self.extract_context_from_search(hits)

    # Stage 3: let the LLM answer using that context.
    generation_started = time.time()
    reply = await self.generate_response(user_message, context_text)
    generation_elapsed = time.time() - generation_started

    overall_elapsed = time.time() - overall_started

    # Details shown alongside the reply in the chat log.
    details = {
        "搜索时间": f"{search_elapsed:.2f}s",
        "生成时间": f"{generation_elapsed:.2f}s",
        "总时间": f"{overall_elapsed:.2f}s",
        "搜索结果": self.get_search_summary(hits),
        "重排权重": settings["rerank_alpha"],
    }
    return reply, details
|
||||
|
||||
def extract_context_from_search(self, search_results: Dict) -> str:
    """Turn hybrid-search results into a plain-text context for the prompt.

    Uses ``search_results["reranked_results"]`` when that key is present;
    otherwise merges the per-category lists found under ``"keyword_search"``
    and ``"embedding_search"``. Only the top 5 statements, top 3 chunks and
    top 5 entities are rendered.

    Args:
        search_results: Mapping of search sections to per-category lists of
            result dicts.

    Returns:
        The formatted context, or a fixed "nothing found" message when there
        is nothing to show.
    """
    nothing_found = "未找到相关信息。"
    if not search_results:
        return nothing_found

    def relevance(item):
        # Prefer the reranked score and fall back to the raw score. A key
        # that is present but None is treated as missing, so the ".3f"
        # formatting below cannot fail on it.
        for key in ("combined_score", "score"):
            value = item.get(key)
            if value is not None:
                return value
        return 0

    if "reranked_results" in search_results:
        results = search_results["reranked_results"] or {}
    else:
        # No reranked view: concatenate keyword and embedding hits by
        # category, tolerating absent or None sections and lists.
        results = {}
        for section in ("keyword_search", "embedding_search"):
            for category, items in (search_results.get(section) or {}).items():
                if items:
                    results.setdefault(category, []).extend(items)

    context_parts = []

    # (category, header, text field, number of entries to keep)
    scored_sections = (
        ("statements", "相关陈述:", "statement", 5),
        ("chunks", "\n相关对话:", "content", 3),
    )
    for category, header, text_key, top_n in scored_sections:
        items = results.get(category)
        if not items:
            continue
        context_parts.append(header)
        for rank, item in enumerate(items[:top_n], 1):
            content = item.get(text_key, "")
            context_parts.append(
                f"{rank}. {content} (相关度: {relevance(item):.3f})"
            )

    entities = results.get("entities")
    if entities:
        context_parts.append("\n相关实体:")
        # A None name is rendered as an empty string instead of breaking join.
        context_parts.append(
            ", ".join(str(ent.get("name") or "") for ent in entities[:5])
        )

    return "\n".join(context_parts) if context_parts else nothing_found
|
||||
|
||||
def get_search_summary(self, search_results: Dict) -> str:
    """Return a one-line summary of how many results each search produced.

    Args:
        search_results: Mapping produced by the search step; only its
            ``"combined_summary"`` entry is inspected.

    Returns:
        ``"无结果"`` for empty input, a comma-separated list of the counters
        present in the combined summary, or ``"有结果"`` when results exist
        but no counter is available.
    """
    if not search_results:
        return "无结果"

    # A missing or None summary is treated the same way: no counters.
    summary = search_results.get("combined_summary") or {}

    # (counter key, label shown to the user), in display order.
    counters = (
        ("total_reranked_results", "重排结果"),
        ("total_keyword_results", "关键词"),
        ("total_embedding_results", "语义"),
    )
    summary_parts = [
        f"{label}: {summary[key]}" for key, label in counters if key in summary
    ]

    return ", ".join(summary_parts) if summary_parts else "有结果"
|
||||
|
||||
async def generate_response(self, user_message: str, context: str) -> str:
    """Ask the LLM for an answer grounded in the retrieved context.

    Args:
        user_message: The user's question.
        context: Text built from the search results; it is embedded in the
            system prompt.

    Returns:
        The text extracted from the LLM reply, a fixed apology when no text
        can be extracted, or an error message when the call raises.
    """
    # Local import: only this method needs it and the file-level import
    # block is not visible from here.
    import inspect

    system_prompt = f"""你是一个智能助手,基于知识图谱中的信息回答用户问题。

以下是从知识图谱中检索到的相关信息:
{context}

请基于这些信息回答用户的问题。如果信息不足,请诚实地说明。回答要自然、友好,并且准确。"""

    try:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        response = self.llm_client.chat(messages=messages)
        # NOTE(review): whether ``chat`` is sync or async is not visible
        # here.  Awaiting an awaitable result supports both; without this an
        # async client would always end in the fallback reply below.
        if inspect.isawaitable(response):
            response = await response

        # 1) LangChain AIMessage or similar object exposing `.content`.
        if hasattr(response, "content"):
            return response.content

        # 2) OpenAI-style object exposing `.choices`.
        if hasattr(response, "choices") and response.choices:
            first_choice = response.choices[0]
            # Newer clients nest the text under `.message.content`; some
            # expose `.content` directly on the choice.
            if hasattr(first_choice, "message") and hasattr(first_choice.message, "content"):
                return first_choice.message.content
            if hasattr(first_choice, "content"):
                return first_choice.content

        # 3) Dict-shaped responses mirroring the two layouts above.
        if isinstance(response, dict):
            if "content" in response:
                return response["content"]
            if "choices" in response and response["choices"]:
                choice = response["choices"][0]
                if isinstance(choice, dict):
                    if "message" in choice and "content" in choice["message"]:
                        return choice["message"]["content"]
                    if "content" in choice:
                        return choice["content"]

        # 4) The client already returned plain text.
        if isinstance(response, str):
            return response

        return "抱歉,我无法生成回复。"

    except Exception as e:
        # Deliberate best-effort boundary: the failure is reported to the
        # user as the reply text instead of propagating.
        return f"生成回复时出错: {str(e)}"
|
||||
|
||||
def on_response_ready(self, response: str, metadata: Dict[str, Any]):
    """Show the assistant's reply and make the input controls usable again.

    Args:
        response: Reply text to append to the conversation.
        metadata: Extra data passed through to ``add_message``.
    """
    self.add_message("助手", response, metadata)
    # Re-enable sending and reset the status line to "ready".
    self.send_button.config(state=tk.NORMAL)
    self.status_label.config(text="就绪")
    self.user_input.focus()
|
||||
|
||||
def on_error(self, error_message: str):
    """Show an error as a system message and make the input usable again.

    Args:
        error_message: Text describing what went wrong.
    """
    self.add_message("系统", f" {error_message}")
    # Re-enable sending and reset the status line to "ready".
    self.send_button.config(state=tk.NORMAL)
    self.status_label.config(text="就绪")
    self.user_input.focus()
|
||||
|
||||
def run(self):
    """Start the chatbot by entering the main loop of ``self.root``."""
    self.root.mainloop()
|
||||
|
||||
|
||||
def main():
    """Create the chatbot and run it, printing any start-up failure."""
    try:
        chatbot = HybridSearchChatbot()
        chatbot.run()
    except Exception as e:
        # Top-level boundary of the script: report the failure instead of
        # showing a traceback.
        print(f"启动聊天机器人时出错: {e}")


if __name__ == "__main__":
    main()
|
||||
@@ -1,408 +1,408 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""混合搜索策略
|
||||
# # -*- coding: utf-8 -*-
|
||||
# """混合搜索策略
|
||||
|
||||
结合关键词搜索和语义搜索的混合检索方法。
|
||||
支持结果重排序和遗忘曲线加权。
|
||||
"""
|
||||
# 结合关键词搜索和语义搜索的混合检索方法。
|
||||
# 支持结果重排序和遗忘曲线加权。
|
||||
# """
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
import math
|
||||
from datetime import datetime
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
# from typing import List, Dict, Any, Optional
|
||||
# import math
|
||||
# from datetime import datetime
|
||||
# from app.core.logging_config import get_memory_logger
|
||||
# from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
# from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
# from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
|
||||
# from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
# from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
# from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
# logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
class HybridSearchStrategy(SearchStrategy):
    """Hybrid search strategy.

    Combines two retrieval methods:

    - keyword search: exact matching, suited to known terms
    - semantic search: embedding based, suited to conceptual queries
    - hybrid reranking: merges the scores of both searches
    - forgetting curve: optionally decays relevance with elapsed time
    """

    # Result categories carried by a SearchResult, in processing order.
    _CATEGORIES = ("statements", "chunks", "entities", "summaries")

    def __init__(
        self,
        connector: Optional[Neo4jConnector] = None,
        embedder_client: Optional[OpenAIEmbedderClient] = None,
        alpha: float = 0.6,
        use_forgetting_curve: bool = False,
        forgetting_config: Optional[ForgettingEngineConfig] = None
    ):
        """Initialise the hybrid strategy and its two sub-strategies.

        Args:
            connector: Neo4j connector.  When omitted, one is created in
                ``__aenter__`` and closed in ``__aexit__``.
            embedder_client: Embedding model client.
            alpha: Weight of the BM25 (keyword) score, 0.0-1.0; the
                embedding score is weighted ``1 - alpha``.
            use_forgetting_curve: Whether to apply the forgetting curve.
            forgetting_config: Forgetting engine configuration; a default
                ``ForgettingEngineConfig()`` is used when omitted.
        """
        self.connector = connector
        self.embedder_client = embedder_client
        self.alpha = alpha
        self.use_forgetting_curve = use_forgetting_curve
        self.forgetting_config = forgetting_config or ForgettingEngineConfig()
        # Remember whether this instance is responsible for the connector.
        self._owns_connector = connector is None

        # Sub-strategies share this instance's connector.
        self.keyword_strategy = KeywordSearchStrategy(connector=connector)
        self.semantic_strategy = SemanticSearchStrategy(
            connector=connector,
            embedder_client=embedder_client
        )

    async def __aenter__(self):
        """Async context-manager entry; creates the connector if none was given."""
        if self._owns_connector:
            self.connector = Neo4jConnector()
            self.keyword_strategy.connector = self.connector
            self.semantic_strategy.connector = self.connector
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context-manager exit; closes a connector created by this instance."""
        if self._owns_connector and self.connector:
            await self.connector.close()

    async def search(
        self,
        query_text: str,
        group_id: Optional[str] = None,
        limit: int = 50,
        include: Optional[List[str]] = None,
        **kwargs
    ) -> SearchResult:
        """Run the keyword and semantic searches and rerank the merged result.

        Args:
            query_text: Query text.
            group_id: Optional group id filter.
            limit: Maximum number of results per category.
            include: Categories to search.
            **kwargs: Per-call overrides ``alpha`` and ``use_forgetting_curve``.

        Returns:
            SearchResult: The reranked result.  On any failure an empty
            result is returned whose metadata carries the error text.
        """
        logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")

        # Per-call overrides fall back to the instance defaults.
        alpha = kwargs.get("alpha", self.alpha)
        use_forgetting = kwargs.get("use_forgetting_curve", self.use_forgetting_curve)

        # Resolve the effective category list.
        include_list = self._get_include_list(include)

        try:
            # The two searches are awaited one after the other (not in
            # parallel): keyword first, then semantic.
            keyword_result = await self.keyword_strategy.search(
                query_text=query_text,
                group_id=group_id,
                limit=limit,
                include=include_list
            )
            semantic_result = await self.semantic_strategy.search(
                query_text=query_text,
                group_id=group_id,
                limit=limit,
                include=include_list
            )

            # Both rerankers share one signature, so pick one and call it.
            rerank = (
                self._rerank_with_forgetting_curve
                if use_forgetting
                else self._rerank_hybrid_results
            )
            reranked_results = rerank(
                keyword_result=keyword_result,
                semantic_result=semantic_result,
                alpha=alpha,
                limit=limit
            )

            metadata = self._create_metadata(
                query_text=query_text,
                search_type="hybrid",
                group_id=group_id,
                limit=limit,
                include=include_list,
                alpha=alpha,
                use_forgetting_curve=use_forgetting
            )

            # Result statistics.
            metadata["keyword_results"] = keyword_result.metadata.get("result_counts", {})
            metadata["semantic_results"] = semantic_result.metadata.get("result_counts", {})
            metadata["total_keyword_results"] = keyword_result.total_results()
            metadata["total_semantic_results"] = semantic_result.total_results()
            metadata["total_reranked_results"] = reranked_results.total_results()

            reranked_results.metadata = metadata

            logger.info(f"混合搜索完成: 共找到 {reranked_results.total_results()} 条结果")
            return reranked_results

        except Exception as e:
            logger.error(f"混合搜索失败: {e}", exc_info=True)
            # Deliberate best-effort: return an empty result that carries
            # the error message instead of raising.
            return SearchResult(
                metadata=self._create_metadata(
                    query_text=query_text,
                    search_type="hybrid",
                    group_id=group_id,
                    limit=limit,
                    error=str(e)
                )
            )

    def _normalize_scores(
        self,
        results: List[Dict[str, Any]],
        score_field: str = "score"
    ) -> List[Dict[str, Any]]:
        """Normalise scores with a z-score followed by a sigmoid.

        The items are updated in place: every item that has ``score_field``
        gains a ``normalized_<score_field>`` entry.  Missing or non-numeric
        scores count as 0.0.

        Args:
            results: Result items.
            score_field: Name of the score entry.

        Returns:
            List[Dict[str, Any]]: The same list that was passed in.
        """
        if not results:
            return results

        normalized_field = f"normalized_{score_field}"
        scored_items = [item for item in results if score_field in item]

        def as_number(value: Any) -> float:
            return float(value) if isinstance(value, (int, float)) else 0.0

        scores = [as_number(item[score_field]) for item in scored_items]

        # Zero or one score: there is nothing to spread, use 1.0.
        if len(scores) <= 1:
            for item in scored_items:
                item[normalized_field] = 1.0
            return results

        mean_score = sum(scores) / len(scores)
        variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
        std_dev = math.sqrt(variance)

        for item, score in zip(scored_items, scores):
            if std_dev == 0:
                # All scores are identical.
                item[normalized_field] = 1.0
            else:
                z_score = (score - mean_score) / std_dev
                item[normalized_field] = 1 / (1 + math.exp(-z_score))

        return results

    def _merge_category(
        self,
        keyword_items: List[Dict[str, Any]],
        semantic_items: List[Dict[str, Any]]
    ) -> Dict[Any, Dict[str, Any]]:
        """Merge one category of keyword and semantic hits by item id.

        Each merged entry is a copy of the hit with ``bm25_score`` and
        ``embedding_score`` set to the normalised score of the respective
        search (0 when that search did not return the item).  Items without
        an ``id``/``uuid`` are dropped.  When both searches return an item,
        the keyword copy is kept and only its ``embedding_score`` is set.
        If the keyword list repeats an id, the last occurrence wins.
        """
        keyword_items = self._normalize_scores(keyword_items, "score")
        semantic_items = self._normalize_scores(semantic_items, "score")

        merged: Dict[Any, Dict[str, Any]] = {}

        for item in keyword_items:
            item_id = item.get("id") or item.get("uuid")
            if not item_id:
                continue
            entry = item.copy()
            entry["bm25_score"] = item.get("normalized_score", 0)
            entry["embedding_score"] = 0
            merged[item_id] = entry

        for item in semantic_items:
            item_id = item.get("id") or item.get("uuid")
            if not item_id:
                continue
            entry = merged.get(item_id)
            if entry is None:
                entry = item.copy()
                entry["bm25_score"] = 0
                merged[item_id] = entry
            entry["embedding_score"] = item.get("normalized_score", 0)

        return merged

    @staticmethod
    def _top_items(items, limit: int) -> List[Dict[str, Any]]:
        """Return the ``limit`` items with the highest ``combined_score``."""
        return sorted(
            items,
            key=lambda entry: entry.get("combined_score", 0),
            reverse=True
        )[:limit]

    def _rerank_hybrid_results(
        self,
        keyword_result: SearchResult,
        semantic_result: SearchResult,
        alpha: float,
        limit: int
    ) -> SearchResult:
        """Rerank the merged results by the weighted sum of both scores.

        Args:
            keyword_result: Keyword search result.
            semantic_result: Semantic search result.
            alpha: Weight of the BM25 score.
            limit: Maximum number of items per category.

        Returns:
            SearchResult: The reranked result.
        """
        reranked_data = {}

        for category in self._CATEGORIES:
            merged = self._merge_category(
                getattr(keyword_result, category, []),
                getattr(semantic_result, category, [])
            )

            for item in merged.values():
                bm25_score = item.get("bm25_score", 0)
                embedding_score = item.get("embedding_score", 0)
                item["combined_score"] = alpha * bm25_score + (1 - alpha) * embedding_score

            reranked_data[category] = self._top_items(merged.values(), limit)

        return SearchResult(
            statements=reranked_data.get("statements", []),
            chunks=reranked_data.get("chunks", []),
            entities=reranked_data.get("entities", []),
            summaries=reranked_data.get("summaries", [])
        )

    def _parse_datetime(self, value: Any) -> Optional[datetime]:
        """Parse an ISO-8601 date-time.

        Args:
            value: ``None``, a ``datetime`` (returned unchanged) or a string.

        Returns:
            Optional[datetime]: The parsed value, or ``None`` for empty,
            unparsable or unsupported input.
        """
        if isinstance(value, datetime):
            return value
        if not isinstance(value, str):
            return None

        text = value.strip()
        if not text:
            return None
        # ``fromisoformat`` only accepts a trailing "Z" from Python 3.11 on;
        # spell it as an explicit UTC offset so older versions parse it too.
        if text[-1] in "Zz":
            text = text[:-1] + "+00:00"
        try:
            return datetime.fromisoformat(text)
        except ValueError:
            return None

    def _rerank_with_forgetting_curve(
        self,
        keyword_result: SearchResult,
        semantic_result: SearchResult,
        alpha: float,
        limit: int
    ) -> SearchResult:
        """Rerank the merged results and decay them with the forgetting curve.

        The weighted sum of both scores is multiplied by the weight the
        forgetting engine returns for the item's age, taken from its
        ``created_at`` entry (age 0 when that is missing or unparsable).
        Each item also gets ``forgetting_weight`` and ``time_elapsed_days``.

        Args:
            keyword_result: Keyword search result.
            semantic_result: Semantic search result.
            alpha: Weight of the BM25 score.
            limit: Maximum number of items per category.

        Returns:
            SearchResult: The reranked result.
        """
        engine = ForgettingEngine(self.forgetting_config)
        naive_now = datetime.now()

        reranked_data = {}

        for category in self._CATEGORIES:
            merged = self._merge_category(
                getattr(keyword_result, category, []),
                getattr(semantic_result, category, [])
            )

            for item in merged.values():
                bm25_score = float(item.get("bm25_score", 0) or 0)
                embedding_score = float(item.get("embedding_score", 0) or 0)
                combined_score = alpha * bm25_score + (1 - alpha) * embedding_score

                # Time decay.
                created = self._parse_datetime(item.get("created_at"))
                if created is None:
                    time_elapsed_days = 0.0
                else:
                    # Naive and aware datetimes cannot be subtracted from
                    # each other, so take "now" in the timestamp's own zone
                    # when it carries one.
                    if created.utcoffset() is None:
                        now_dt = naive_now
                    else:
                        now_dt = datetime.now(created.tzinfo)
                    time_elapsed_days = max(
                        0.0, (now_dt - created).total_seconds() / 86400.0
                    )

                memory_strength = 1.0  # default strength
                forgetting_weight = engine.calculate_weight(
                    time_elapsed=time_elapsed_days,
                    memory_strength=memory_strength
                )

                item["combined_score"] = combined_score * forgetting_weight
                item["forgetting_weight"] = forgetting_weight
                item["time_elapsed_days"] = time_elapsed_days

            reranked_data[category] = self._top_items(merged.values(), limit)

        return SearchResult(
            statements=reranked_data.get("statements", []),
            chunks=reranked_data.get("chunks", []),
            entities=reranked_data.get("entities", []),
            summaries=reranked_data.get("summaries", [])
        )
|
||||
|
||||
@@ -1,445 +0,0 @@
|
||||
# Memory 模块工具函数文档
|
||||
|
||||
本目录包含 Memory 模块使用的所有工具函数,统一管理以提高代码可维护性和可复用性。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
app/core/memory/utils/
|
||||
├── __init__.py # 包初始化文件,导出所有公共接口
|
||||
├── README.md # 本文档
|
||||
├── config/ # 配置管理模块
|
||||
│ ├── __init__.py # 配置模块初始化
|
||||
│ ├── config_utils.py # 配置管理工具
|
||||
│ ├── definitions.py # 全局定义和常量
|
||||
│ ├── overrides.py # 运行时配置覆写
|
||||
│ ├── get_data.py # 数据获取工具
|
||||
│ ├── litellm_config.py # LiteLLM 配置和监控
|
||||
│ └── config_optimization.py # 配置优化工具
|
||||
├── log/ # 日志管理模块
|
||||
│ ├── __init__.py # 日志模块初始化
|
||||
│ ├── logging_utils.py # 日志工具
|
||||
│ └── audit_logger.py # 审计日志
|
||||
├── prompt/ # 提示词管理模块
|
||||
│ ├── __init__.py # 提示词模块初始化
|
||||
│ ├── prompt_utils.py # 提示词渲染工具
|
||||
│ ├── template_render.py # 模板渲染工具
|
||||
│ └── prompts/ # Jinja2 提示词模板目录
|
||||
│ ├── entity_dedup.jinja2 # 实体去重提示词
|
||||
│ ├── extract_statement.jinja2 # 陈述句提取提示词
|
||||
│ ├── extract_temporal.jinja2 # 时间信息提取提示词
|
||||
│ ├── extract_triplet.jinja2 # 三元组提取提示词
|
||||
│ ├── memory_summary.jinja2 # 记忆摘要提示词
|
||||
│ ├── evaluate.jinja2 # 评估提示词
|
||||
│ ├── reflexion.jinja2 # 反思提示词
|
||||
│ ├── system.jinja2 # 系统提示词
|
||||
│ └── user.jinja2 # 用户提示词
|
||||
├── llm/ # LLM 工具模块
|
||||
│ ├── __init__.py # LLM 模块初始化
|
||||
│ └── llm_utils.py # LLM 客户端工具
|
||||
├── data/ # 数据处理模块
|
||||
│ ├── __init__.py # 数据模块初始化
|
||||
│ ├── text_utils.py # 文本处理工具
|
||||
│ ├── time_utils.py # 时间处理工具
|
||||
│ └── ontology.py # 本体定义(谓语、标签等)
|
||||
├── paths/ # 路径管理模块
|
||||
│ ├── __init__.py # 路径模块初始化
|
||||
│ └── output_paths.py # 输出路径管理
|
||||
├── visualization/ # 可视化模块
|
||||
│ ├── __init__.py # 可视化模块初始化
|
||||
│ └── forgetting_visualizer.py # 遗忘曲线可视化
|
||||
└── self_reflexion_utils/ # 自我反思工具模块
|
||||
├── __init__.py # 反思模块初始化
|
||||
├── evaluate.py # 冲突评估
|
||||
├── reflexion.py # 反思处理
|
||||
└── self_reflexion.py # 自我反思主逻辑
|
||||
```
|
||||
|
||||
## 模块分类
|
||||
|
||||
### 1. 配置管理(config/)
|
||||
|
||||
配置管理模块包含所有与配置相关的工具函数和定义。
|
||||
|
||||
#### config_utils.py
|
||||
提供配置加载和管理功能:
|
||||
- `get_model_config(model_id)` - 获取 LLM 模型配置
|
||||
- `get_embedder_config(embedding_id)` - 获取嵌入模型配置
|
||||
- `get_neo4j_config()` - 获取 Neo4j 数据库配置
|
||||
- `get_chunker_config(chunker_strategy)` - 获取分块策略配置
|
||||
- `get_pipeline_config()` - 获取流水线配置
|
||||
- `get_pruning_config()` - 获取语义剪枝配置
|
||||
- `get_picture_config()` - 获取图片模型配置
|
||||
- `get_voice_config()` - 获取语音模型配置
|
||||
|
||||
#### definitions.py
|
||||
全局定义和常量:
|
||||
- `CONFIG` - 基础配置(从 config.json 加载)
|
||||
- `RUNTIME_CONFIG` - 运行时配置(从 runtime.json 或数据库加载)
|
||||
- `PROJECT_ROOT` - 项目根目录路径
|
||||
- 各种选择配置常量(LLM、嵌入模型、分块策略等)
|
||||
- `reload_configuration_from_database(config_id)` - 动态重新加载配置
|
||||
|
||||
#### overrides.py
|
||||
运行时配置覆写:
|
||||
- `load_unified_config(project_root)` - 加载统一配置
|
||||
|
||||
#### get_data.py
|
||||
数据获取工具:
|
||||
- `get_data(host_id)` - 从 SQL 数据库获取数据
|
||||
|
||||
#### litellm_config.py
|
||||
LiteLLM 配置和监控:
|
||||
- `LiteLLMConfig` - LiteLLM 配置类
|
||||
- `setup_litellm_enhanced(max_retries)` - 设置增强的 LiteLLM 配置
|
||||
- `get_usage_summary()` - 获取使用统计摘要
|
||||
- `print_usage_summary()` - 打印使用统计
|
||||
- `get_instant_qps(module)` - 获取即时 QPS 数据
|
||||
- `print_instant_qps(module)` - 打印即时 QPS 信息
|
||||
|
||||
#### config_optimization.py
|
||||
配置优化工具:
|
||||
- 配置参数优化相关功能
|
||||
|
||||
### 3. LLM 工具(llm/)
|
||||
|
||||
LLM 工具模块包含所有与 LLM 客户端相关的工具函数。
|
||||
|
||||
#### llm_utils.py
|
||||
LLM 客户端工具:
|
||||
- `get_llm_client(llm_id)` - 获取 LLM 客户端实例
|
||||
- `get_reranker_client(rerank_id)` - 获取重排序客户端实例
|
||||
- `handle_response(response)` - 处理 LLM 响应
|
||||
|
||||
#### litellm_config.py(文件位于 config/ 目录,与上文"配置管理"一节所述为同一文件)
|
||||
LiteLLM 配置和监控:
|
||||
- `LiteLLMConfig` - LiteLLM 配置类
|
||||
- `setup_litellm_enhanced(max_retries)` - 设置增强的 LiteLLM 配置
|
||||
- `get_usage_summary()` - 获取使用统计摘要
|
||||
- `print_usage_summary()` - 打印使用统计
|
||||
- `get_instant_qps(module)` - 获取即时 QPS 数据
|
||||
- `print_instant_qps(module)` - 打印即时 QPS 信息
|
||||
|
||||
### 4. 提示词管理(prompt/)
|
||||
|
||||
提示词管理模块包含所有提示词渲染和模板管理相关的工具函数。
|
||||
|
||||
#### prompt_utils.py
|
||||
提示词渲染工具(使用 Jinja2 模板):
|
||||
- `get_prompts(message)` - 获取系统和用户提示词
|
||||
- `render_statement_extraction_prompt(...)` - 渲染陈述句提取提示词
|
||||
- `render_temporal_extraction_prompt(...)` - 渲染时间信息提取提示词
|
||||
- `render_entity_dedup_prompt(...)` - 渲染实体去重提示词
|
||||
- `render_triplet_extraction_prompt(...)` - 渲染三元组提取提示词
|
||||
- `render_memory_summary_prompt(...)` - 渲染记忆摘要提示词
|
||||
- `prompt_env` - Jinja2 环境对象
|
||||
|
||||
#### template_render.py
|
||||
模板渲染工具(用于评估和反思):
|
||||
- `render_evaluate_prompt(evaluate_data, schema)` - 渲染评估提示词
|
||||
- `render_reflexion_prompt(data, schema)` - 渲染反思提示词
|
||||
|
||||
#### prompts/
|
||||
Jinja2 模板文件目录,包含所有提示词模板
|
||||
|
||||
### 5. 数据处理(data/)
|
||||
|
||||
数据处理模块包含所有数据处理相关的工具函数。
|
||||
|
||||
#### text_utils.py
|
||||
文本处理工具:
|
||||
- `escape_lucene_query(query)` - 转义 Lucene 查询特殊字符
|
||||
- `extract_plain_query(query_input)` - 从各种输入格式提取纯文本查询
|
||||
|
||||
#### time_utils.py
|
||||
时间处理工具:
|
||||
- `validate_date_format(date_str)` - 验证日期格式(YYYY-MM-DD)
|
||||
- `normalize_date(date_str)` - 标准化日期格式
|
||||
- `normalize_date_safe(date_str, default)` - 安全的日期标准化(带默认值)
|
||||
- `preprocess_date_string(date_str)` - 预处理日期字符串
|
||||
|
||||
#### ontology.py
|
||||
本体定义:
|
||||
- `PREDICATE_DEFINITIONS` - 谓语定义字典
|
||||
- `LABEL_DEFINITIONS` - 标签定义字典
|
||||
- `Predicate` - 谓语枚举
|
||||
- `StatementType` - 陈述句类型枚举
|
||||
- `TemporalInfo` - 时间信息枚举
|
||||
- `RelevenceInfo` - 相关性信息枚举
|
||||
|
||||
### 2. 日志管理(log/)
|
||||
|
||||
日志管理模块包含所有与日志记录相关的工具函数。
|
||||
|
||||
#### logging_utils.py
|
||||
日志工具:
|
||||
- `log_prompt_rendering(role, content)` - 记录提示词渲染
|
||||
- `log_template_rendering(template_name, context)` - 记录模板渲染
|
||||
- `log_time(operation, duration)` - 记录操作耗时
|
||||
- `prompt_logger` - 提示词日志记录器
|
||||
|
||||
#### audit_logger.py
|
||||
审计日志:
|
||||
- `audit_logger` - 审计日志记录器
|
||||
- 记录系统关键操作和安全事件
|
||||
|
||||
### 6. 自我反思工具(self_reflexion_utils/)
|
||||
|
||||
自我反思工具模块包含记忆冲突检测和反思处理功能。
|
||||
|
||||
#### evaluate.py
|
||||
冲突评估:
|
||||
- `conflict(evaluate_data, schema)` - 评估记忆冲突
|
||||
|
||||
#### reflexion.py
|
||||
反思处理:
|
||||
- `reflexion(data, schema)` - 执行反思处理
|
||||
|
||||
#### self_reflexion.py
|
||||
自我反思主逻辑:
|
||||
- `self_reflexion(...)` - 自我反思主函数
|
||||
|
||||
### 7. 数据模型
|
||||
|
||||
#### json_schema.py
|
||||
JSON Schema 数据模型:
|
||||
- `BaseDataSchema` - 基础数据模型
|
||||
- `ConflictResultSchema` - 冲突结果模型
|
||||
- `ConflictSchema` - 冲突模型
|
||||
- `ReflexionSchema` - 反思模型
|
||||
- `ResolvedSchema` - 解决方案模型
|
||||
- `ReflexionResultSchema` - 反思结果模型
|
||||
|
||||
#### messages.py
|
||||
API 消息模型:
|
||||
- `ConfigKey` - 配置键模型
|
||||
- `ChunkerStrategy` - 分块策略枚举
|
||||
- `ConfigParams` - 配置参数模型
|
||||
- `ConfigParamsCreate` - 创建配置参数模型
|
||||
- `ConfigUpdate` - 更新配置模型
|
||||
- `ConfigUpdateExtracted` - 更新萃取引擎配置模型
|
||||
- `ConfigUpdateForget` - 更新遗忘引擎配置模型
|
||||
- `ConfigPilotRun` - 试运行配置模型
|
||||
- `ConfigFilter` - 配置过滤模型
|
||||
- `ApiResponse` - API 响应模型
|
||||
- `ok(msg, data)` - 成功响应构造函数
|
||||
- `fail(msg, error_code, data)` - 失败响应构造函数
|
||||
|
||||
### 8. 可视化(visualization/)
|
||||
|
||||
可视化模块包含所有可视化相关的工具函数。
|
||||
|
||||
#### forgetting_visualizer.py
|
||||
遗忘曲线可视化:
|
||||
- `export_memory_curve_numpy(...)` - 导出记忆曲线为 NumPy 数组
|
||||
- `export_memory_curves_multiple_strengths(...)` - 导出多个强度的记忆曲线
|
||||
- `export_parameter_sweep_numpy(...)` - 导出参数扫描结果
|
||||
- `visualize_forgetting_curve(...)` - 可视化遗忘曲线
|
||||
- `plot_3d_forgetting_surface(...)` - 绘制 3D 遗忘曲线表面
|
||||
- `create_comparison_visualization(...)` - 创建对比可视化
|
||||
- `save_memory_curves_to_file(...)` - 保存记忆曲线到文件
|
||||
|
||||
### 9. 路径管理(paths/)
|
||||
|
||||
路径管理模块包含所有路径管理相关的工具函数。
|
||||
|
||||
#### output_paths.py
|
||||
输出路径管理:
|
||||
- `get_output_dir()` - 获取输出目录
|
||||
- `get_output_path(filename)` - 获取输出文件路径
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 配置管理
|
||||
|
||||
```python
|
||||
from app.core.memory.utils.config import get_model_config, get_pipeline_config
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
|
||||
|
||||
# 获取模型配置
|
||||
model_config = get_model_config("model_id_123")
|
||||
|
||||
# 获取流水线配置
|
||||
pipeline_config = get_pipeline_config()
|
||||
|
||||
# 使用全局常量
|
||||
llm_id = SELECTED_LLM_ID
|
||||
```
|
||||
|
||||
### 日志管理
|
||||
|
||||
```python
|
||||
from app.core.memory.utils.log import log_prompt_rendering, log_time, audit_logger
|
||||
|
||||
# 记录提示词渲染
|
||||
log_prompt_rendering('user', 'Hello, world!')
|
||||
|
||||
# 记录操作耗时
|
||||
log_time('extraction', 1.23)
|
||||
|
||||
# 使用审计日志
|
||||
audit_logger.info('User action performed')
|
||||
```
|
||||
|
||||
### LLM 工具
|
||||
|
||||
```python
|
||||
from app.core.memory.utils.llm import get_llm_client
|
||||
|
||||
# 获取 LLM 客户端
|
||||
llm_client = get_llm_client("llm_id_456")
|
||||
|
||||
# 调用 LLM
|
||||
response = await llm_client.chat([
|
||||
{"role": "user", "content": "Hello"}
|
||||
])
|
||||
```
|
||||
|
||||
### 提示词渲染
|
||||
|
||||
```python
|
||||
from app.core.memory.utils.prompt import render_statement_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS
|
||||
|
||||
# 渲染陈述句提取提示词
|
||||
prompt = await render_statement_extraction_prompt(
|
||||
chunk_content="对话内容...",
|
||||
definitions=LABEL_DEFINITIONS,
|
||||
json_schema=schema,
|
||||
granularity=2
|
||||
)
|
||||
```
|
||||
|
||||
### 数据处理
|
||||
|
||||
```python
|
||||
from app.core.memory.utils.data.time_utils import normalize_date
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
# 标准化日期
|
||||
normalized = normalize_date("2025/10/28") # 返回 "2025-10-28"
|
||||
|
||||
# 转义 Lucene 查询
|
||||
escaped = escape_lucene_query("user:admin AND status:active")
|
||||
```
|
||||
|
||||
### 运行时配置覆写
|
||||
|
||||
```python
|
||||
from app.core.memory.utils import apply_runtime_overrides_with_config_id
|
||||
|
||||
# 使用指定 config_id 覆写配置
|
||||
runtime_cfg = {"selections": {}}
|
||||
updated_cfg = apply_runtime_overrides_with_config_id(
|
||||
project_root="/path/to/project",
|
||||
runtime_cfg=runtime_cfg,
|
||||
config_id="config_123"
|
||||
)
|
||||
```
|
||||
|
||||
## 迁移说明
|
||||
|
||||
### 从旧路径迁移
|
||||
|
||||
如果你的代码使用了旧的导入路径,请按以下方式更新:
|
||||
|
||||
**旧路径(2024年11月之前):**
|
||||
```python
|
||||
from app.core.memory.src.utils.config_utils import get_model_config
|
||||
from app.core.memory.src.utils.prompt_utils import render_statement_extraction_prompt
|
||||
from app.core.memory.src.data_config_api.utils.messages import ok, fail
|
||||
```
|
||||
|
||||
**中间路径(2024年11月):**
|
||||
```python
|
||||
from app.core.memory.utils.config_utils import get_model_config
|
||||
from app.core.memory.utils.logging_utils import log_prompt_rendering
|
||||
from app.schemas.memory_storage_schema import ok, fail
|
||||
```
|
||||
|
||||
**新路径(2024年11月27日之后):**
|
||||
```python
|
||||
# 配置相关
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.memory.utils.config import get_model_config # 简化导入
|
||||
|
||||
# 日志相关
|
||||
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
|
||||
from app.core.memory.utils.log import log_prompt_rendering # 简化导入
|
||||
|
||||
# 其他工具
|
||||
from app.core.memory.utils.prompt import prompt_utils
|
||||
from app.schemas.memory_storage_schema import ok, fail
|
||||
```
|
||||
|
||||
### 目录结构重组(2024年11月27日)
|
||||
|
||||
utils 目录已按功能进行了完整的重组:
|
||||
|
||||
**重组前的结构:**
|
||||
- 所有文件都在 `app/core/memory/utils/` 根目录下
|
||||
|
||||
**重组后的结构:**
|
||||
- `config/` - 配置管理相关文件
|
||||
- `log/` - 日志管理相关文件
|
||||
- `prompt/` - 提示词管理相关文件
|
||||
- `llm/` - LLM 工具相关文件
|
||||
- `data/` - 数据处理相关文件
|
||||
- `paths/` - 路径管理相关文件
|
||||
- `visualization/` - 可视化相关文件
|
||||
- `self_reflexion_utils/` - 自我反思工具(已存在)
|
||||
|
||||
**导入路径变化:**
|
||||
```python
|
||||
# 旧导入方式
|
||||
from app.core.memory.utils.config_utils import get_model_config
|
||||
from app.core.memory.utils.logging_utils import log_prompt_rendering
|
||||
from app.core.memory.utils.prompt_utils import render_statement_extraction_prompt
|
||||
|
||||
# 新导入方式
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||
|
||||
# 或使用简化导入
|
||||
from app.core.memory.utils.config import get_model_config
|
||||
from app.core.memory.utils.log import log_prompt_rendering
|
||||
from app.core.memory.utils.prompt import render_statement_extraction_prompt
|
||||
```
|
||||
|
||||
## 维护指南
|
||||
|
||||
### 添加新工具函数
|
||||
|
||||
1. 在相应的模块文件中添加函数
|
||||
2. 在 `__init__.py` 中导出函数
|
||||
3. 在本 README 中添加文档
|
||||
4. 编写单元测试
|
||||
|
||||
### 删除旧工具函数
|
||||
|
||||
1. 确认没有代码使用该函数
|
||||
2. 从模块文件中删除函数
|
||||
3. 从 `__init__.py` 中删除导出
|
||||
4. 更新本 README
|
||||
|
||||
### 重构工具函数
|
||||
|
||||
1. 保持向后兼容性(使用别名或包装器)
|
||||
2. 更新所有使用该函数的代码
|
||||
3. 更新文档和测试
|
||||
4. 在适当时机删除旧版本
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **向后兼容性**:所有工具函数应保持向后兼容,避免破坏现有代码
|
||||
2. **文档完整性**:每个函数都应有清晰的文档字符串
|
||||
3. **类型注解**:使用类型注解提高代码可读性
|
||||
4. **错误处理**:工具函数应有适当的错误处理
|
||||
5. **测试覆盖**:所有工具函数都应有单元测试
|
||||
|
||||
## 相关文档
|
||||
|
||||
- [Memory 模块架构设计](../.kiro/specs/memory-refactoring/design.md)
|
||||
- [Memory 模块需求文档](../.kiro/specs/memory-refactoring/requirements.md)
|
||||
- [Memory 模块任务列表](../.kiro/specs/memory-refactoring/tasks.md)
|
||||
@@ -6,33 +6,27 @@
|
||||
|
||||
# 从子模块导出常用函数和常量,保持向后兼容
|
||||
from .config_utils import (
|
||||
get_model_config,
|
||||
get_embedder_config,
|
||||
get_neo4j_config,
|
||||
get_chunker_config,
|
||||
get_embedder_config,
|
||||
get_model_config,
|
||||
get_neo4j_config,
|
||||
get_picture_config,
|
||||
get_pipeline_config,
|
||||
get_pruning_config,
|
||||
get_picture_config,
|
||||
get_voice_config,
|
||||
)
|
||||
from .definitions import (
|
||||
CONFIG,
|
||||
RUNTIME_CONFIG,
|
||||
PROJECT_ROOT,
|
||||
SELECTED_LLM_ID,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_RERANK_ID,
|
||||
SELECTED_LLM_PICTURE_NAME,
|
||||
SELECTED_LLM_VOICE_NAME,
|
||||
REFLEXION_ENABLED,
|
||||
REFLEXION_ITERATION_PERIOD,
|
||||
REFLEXION_RANGE,
|
||||
REFLEXION_BASELINE,
|
||||
reload_configuration_from_database,
|
||||
)
|
||||
from .overrides import load_unified_config
|
||||
|
||||
# DEPRECATED: Global configuration variables removed
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
# from .definitions import (
|
||||
# CONFIG, # DEPRECATED - empty dict for backward compatibility
|
||||
# RUNTIME_CONFIG, # DEPRECATED - minimal for backward compatibility
|
||||
# PROJECT_ROOT, # Still needed for file paths
|
||||
# reload_configuration_from_database, # DEPRECATED - returns False
|
||||
# )
|
||||
# DEPRECATED: overrides module removed - use MemoryConfig with dependency injection
|
||||
from .get_data import get_data
|
||||
|
||||
# litellm_config 需要时动态导入,避免循环依赖
|
||||
# from .litellm_config import (
|
||||
# LiteLLMConfig,
|
||||
@@ -53,23 +47,11 @@ __all__ = [
|
||||
"get_pruning_config",
|
||||
"get_picture_config",
|
||||
"get_voice_config",
|
||||
# definitions
|
||||
"CONFIG",
|
||||
"RUNTIME_CONFIG",
|
||||
"PROJECT_ROOT",
|
||||
"SELECTED_LLM_ID",
|
||||
"SELECTED_EMBEDDING_ID",
|
||||
"SELECTED_GROUP_ID",
|
||||
"SELECTED_RERANK_ID",
|
||||
"SELECTED_LLM_PICTURE_NAME",
|
||||
"SELECTED_LLM_VOICE_NAME",
|
||||
"REFLEXION_ENABLED",
|
||||
"REFLEXION_ITERATION_PERIOD",
|
||||
"REFLEXION_RANGE",
|
||||
"REFLEXION_BASELINE",
|
||||
"reload_configuration_from_database",
|
||||
# overrides
|
||||
"load_unified_config",
|
||||
# definitions (DEPRECATED - use MemoryConfig objects instead)
|
||||
# "CONFIG", # DEPRECATED
|
||||
# "RUNTIME_CONFIG", # DEPRECATED
|
||||
# "PROJECT_ROOT",
|
||||
# "reload_configuration_from_database", # DEPRECATED
|
||||
# get_data
|
||||
"get_data",
|
||||
# litellm_config - 需要时从 .litellm_config 直接导入
|
||||
|
||||
@@ -1,22 +1,18 @@
|
||||
import uuid
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi import status
|
||||
|
||||
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
|
||||
from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
DedupConfig,
|
||||
StatementExtractionConfig,
|
||||
ExtractionPipelineConfig,
|
||||
ForgettingEngineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.core.memory.utils.config.definitions import CONFIG
|
||||
from app.db import get_db
|
||||
from app.models.models_model import ModelConfig, ModelApiKey
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def get_model_config(model_id: str, db: Session | None = None) -> dict:
|
||||
if db is None:
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
@@ -110,6 +106,13 @@ def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
|
||||
# 2) Provide sane defaults for newer strategies
|
||||
default_configs = {
|
||||
"RecursiveChunker": {
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"min_characters_per_chunk": 50
|
||||
},
|
||||
|
||||
"LLMChunker": {
|
||||
"chunker_strategy": "LLMChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
@@ -147,94 +150,74 @@ def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available"
|
||||
)
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
def get_pipeline_config() -> ExtractionPipelineConfig:
|
||||
"""Build ExtractionPipelineConfig using only runtime.json values.
|
||||
def get_pipeline_config(
|
||||
config_id: int,
|
||||
db: Session | None = None,
|
||||
) -> ExtractionPipelineConfig:
|
||||
"""Build ExtractionPipelineConfig from database.
|
||||
|
||||
Behavior:
|
||||
- Read `deduplication` section from runtime.json if present.
|
||||
- Read `statement_extraction` section from runtime.json if present.
|
||||
- Read `forgetting_engine` section from runtime.json if present.
|
||||
- If absent, check legacy top-level `enable_llm_dedup` key.
|
||||
- Do NOT fall back to environment variables.
|
||||
- Unspecified fields use model defaults defined in DedupConfig.
|
||||
Args:
|
||||
config_id: Database configuration ID (required). Loads pipeline
|
||||
settings from the DataConfig table.
|
||||
db: Optional database session. If not provided, a new session
|
||||
will be created.
|
||||
|
||||
Returns:
|
||||
ExtractionPipelineConfig with deduplication, statement extraction,
|
||||
and forgetting engine settings loaded from database.
|
||||
|
||||
Raises:
|
||||
ValueError: If config_id not found in database.
|
||||
"""
|
||||
dedup_rc = RUNTIME_CONFIG.get("deduplication", {}) or {}
|
||||
stmt_rc = RUNTIME_CONFIG.get("statement_extraction", {}) or {}
|
||||
forget_rc = RUNTIME_CONFIG.get("forgetting_engine", {}) or {}
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
|
||||
# Assemble kwargs from runtime.json only
|
||||
kwargs = {}
|
||||
# LLM switch: prefer new key, then legacy top-level, default False
|
||||
if "enable_llm_dedup_blockwise" in dedup_rc:
|
||||
kwargs["enable_llm_dedup_blockwise"] = bool(dedup_rc.get("enable_llm_dedup_blockwise"))
|
||||
else:
|
||||
# Legacy top-level fallback inside runtime.json only
|
||||
legacy = RUNTIME_CONFIG.get("enable_llm_dedup")
|
||||
if legacy is not None:
|
||||
kwargs["enable_llm_dedup_blockwise"] = bool(legacy)
|
||||
else:
|
||||
kwargs["enable_llm_dedup_blockwise"] = False # default reserve
|
||||
# Disambiguation switch: only from runtime.json deduplication section
|
||||
if "enable_llm_disambiguation" in dedup_rc:
|
||||
kwargs["enable_llm_disambiguation"] = bool(dedup_rc.get("enable_llm_disambiguation"))
|
||||
# Load from database
|
||||
if db is None:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
db_config = DataConfigRepository.get_by_id(db, config_id)
|
||||
if db_config is None:
|
||||
raise ValueError(f"Configuration {config_id} not found in database")
|
||||
|
||||
# Optional LLM fallback gating
|
||||
if "enable_llm_fallback_only_on_borderline" in dedup_rc:
|
||||
kwargs["enable_llm_fallback_only_on_borderline"] = bool(dedup_rc.get("enable_llm_fallback_only_on_borderline"))
|
||||
# Build DedupConfig from database
|
||||
dedup_kwargs = {
|
||||
"enable_llm_dedup_blockwise": bool(db_config.enable_llm_dedup_blockwise) if db_config.enable_llm_dedup_blockwise is not None else False,
|
||||
"enable_llm_disambiguation": bool(db_config.enable_llm_disambiguation) if db_config.enable_llm_disambiguation is not None else False,
|
||||
}
|
||||
|
||||
# Fuzzy thresholds
|
||||
if db_config.t_name_strict is not None:
|
||||
dedup_kwargs["fuzzy_name_threshold_strict"] = db_config.t_name_strict
|
||||
if db_config.t_type_strict is not None:
|
||||
dedup_kwargs["fuzzy_type_threshold_strict"] = db_config.t_type_strict
|
||||
if db_config.t_overall is not None:
|
||||
dedup_kwargs["fuzzy_overall_threshold"] = db_config.t_overall
|
||||
|
||||
# Optional fuzzy thresholds: use values if provided; otherwise rely on DedupConfig defaults
|
||||
for key in (
|
||||
"fuzzy_name_threshold_strict",
|
||||
"fuzzy_type_threshold_strict",
|
||||
"fuzzy_overall_threshold",
|
||||
"fuzzy_unknown_type_name_threshold",
|
||||
"fuzzy_unknown_type_type_threshold",
|
||||
):
|
||||
if key in dedup_rc:
|
||||
kwargs[key] = dedup_rc[key]
|
||||
dedup_config = DedupConfig(**dedup_kwargs)
|
||||
|
||||
# Optional weights and bonuses for overall scoring
|
||||
for key in (
|
||||
"name_weight",
|
||||
"desc_weight",
|
||||
"type_weight",
|
||||
"context_bonus",
|
||||
"llm_fallback_floor",
|
||||
"llm_fallback_ceiling",
|
||||
):
|
||||
if key in dedup_rc:
|
||||
kwargs[key] = dedup_rc[key]
|
||||
|
||||
# Optional LLM iterative dedup parameters
|
||||
for key in (
|
||||
"llm_block_size",
|
||||
"llm_block_concurrency",
|
||||
"llm_pair_concurrency",
|
||||
"llm_max_rounds",
|
||||
):
|
||||
if key in dedup_rc:
|
||||
kwargs[key] = dedup_rc[key]
|
||||
|
||||
dedup_config = DedupConfig(**kwargs)
|
||||
|
||||
# Build StatementExtractionConfig from runtime.json
|
||||
# Build StatementExtractionConfig from database
|
||||
stmt_kwargs = {}
|
||||
for key in (
|
||||
"statement_granularity",
|
||||
"temperature",
|
||||
"include_dialogue_context",
|
||||
"max_dialogue_context_chars",
|
||||
):
|
||||
if key in stmt_rc:
|
||||
stmt_kwargs[key] = stmt_rc[key]
|
||||
if db_config.statement_granularity is not None:
|
||||
stmt_kwargs["statement_granularity"] = db_config.statement_granularity
|
||||
if db_config.include_dialogue_context is not None:
|
||||
stmt_kwargs["include_dialogue_context"] = bool(db_config.include_dialogue_context)
|
||||
if db_config.max_context is not None:
|
||||
stmt_kwargs["max_dialogue_context_chars"] = db_config.max_context
|
||||
|
||||
stmt_config = StatementExtractionConfig(**stmt_kwargs)
|
||||
|
||||
# Build ForgettingEngineConfig from runtime.json
|
||||
# Build ForgettingEngineConfig from database
|
||||
forget_kwargs = {}
|
||||
for key in ("offset", "lambda_time", "lambda_mem"):
|
||||
if key in forget_rc:
|
||||
forget_kwargs[key] = forget_rc[key]
|
||||
if db_config.offset is not None:
|
||||
forget_kwargs["offset"] = db_config.offset
|
||||
if db_config.lambda_time is not None:
|
||||
forget_kwargs["lambda_time"] = db_config.lambda_time
|
||||
if db_config.lambda_mem is not None:
|
||||
forget_kwargs["lambda_mem"] = db_config.lambda_mem
|
||||
|
||||
forget_config = ForgettingEngineConfig(**forget_kwargs)
|
||||
|
||||
return ExtractionPipelineConfig(
|
||||
@@ -244,24 +227,37 @@ def get_pipeline_config() -> ExtractionPipelineConfig:
|
||||
)
|
||||
|
||||
|
||||
def get_pruning_config() -> dict:
|
||||
"""Retrieve semantic pruning config from runtime.json.
|
||||
def get_pruning_config(
|
||||
config_id: int,
|
||||
db: Session | None = None,
|
||||
) -> dict:
|
||||
"""Retrieve semantic pruning config from database.
|
||||
|
||||
Returns a dict suitable for PruningConfig.model_validate.
|
||||
Args:
|
||||
config_id: Database configuration ID (required).
|
||||
db: Optional database session.
|
||||
|
||||
Structure in runtime.json:
|
||||
{
|
||||
"pruning": {
|
||||
"enabled": true,
|
||||
"scene": "education" | "online_service" | "outbound",
|
||||
"threshold": 0.5
|
||||
}
|
||||
}
|
||||
Returns:
|
||||
Dict suitable for PruningConfig.model_validate with keys:
|
||||
- pruning_switch: bool
|
||||
- pruning_scene: str ("education" | "online_service" | "outbound")
|
||||
- pruning_threshold: float (0-0.9)
|
||||
|
||||
Raises:
|
||||
ValueError: If config_id not found in database.
|
||||
"""
|
||||
pruning_rc = RUNTIME_CONFIG.get("pruning", {}) or {}
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
|
||||
if db is None:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
db_config = DataConfigRepository.get_by_id(db, config_id)
|
||||
if db_config is None:
|
||||
raise ValueError(f"Configuration {config_id} not found in database")
|
||||
|
||||
return {
|
||||
"pruning_switch": bool(pruning_rc.get("enabled", False)),
|
||||
"pruning_scene": pruning_rc.get("scene", "education"),
|
||||
"pruning_threshold": float(pruning_rc.get("threshold", 0.5)),
|
||||
"pruning_switch": bool(db_config.pruning_enabled) if db_config.pruning_enabled is not None else False,
|
||||
"pruning_scene": db_config.pruning_scene or "education",
|
||||
"pruning_threshold": float(db_config.pruning_threshold) if db_config.pruning_threshold is not None else 0.5,
|
||||
}
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
"""
|
||||
配置加载模块 - 三阶段架构(已迁移到统一配置管理)
|
||||
配置加载模块 - DEPRECATED
|
||||
|
||||
本模块现在使用全局配置管理系统 (app/core/config.py)
|
||||
来加载和管理配置,同时保持向后兼容性。
|
||||
⚠️ DEPRECATION NOTICE ⚠️
|
||||
This module is deprecated and will be removed in a future version.
|
||||
Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)
|
||||
阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)
|
||||
Use the new MemoryConfig system instead:
|
||||
- app.core.memory_config.config.MemoryConfig for configuration objects
|
||||
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED
|
||||
阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED
|
||||
阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
@@ -35,21 +43,12 @@ except ImportError:
|
||||
# os.path.dirname(...) = app/core/memory
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# 全局配置锁 - 用于线程安全
|
||||
_config_lock = threading.RLock()
|
||||
# DEPRECATED: Global configuration lock removed
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
|
||||
# 加载基础配置(config.json)- 使用全局配置系统
|
||||
if USE_UNIFIED_CONFIG:
|
||||
CONFIG = settings.load_memory_config()
|
||||
else:
|
||||
# Fallback to legacy loading
|
||||
config_path = os.path.join(PROJECT_ROOT, "config.json")
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
CONFIG = json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
print("Warning: config.json not found or is malformed. Using default settings.")
|
||||
CONFIG = {}
|
||||
# DEPRECATED: Legacy config.json loading removed
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
CONFIG = {}
|
||||
|
||||
DEFAULT_VALUES = {
|
||||
"llm_name": "openai/qwen-plus",
|
||||
@@ -68,35 +67,31 @@ DEFAULT_VALUES = {
|
||||
"reflexion_baseline": "TIME",
|
||||
}
|
||||
|
||||
# DEPRECATED: Legacy global variables for backward compatibility only
|
||||
# These will be removed in a future version
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
|
||||
SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
|
||||
|
||||
|
||||
# 阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
def _load_from_runtime_json() -> Dict[str, Any]:
|
||||
"""
|
||||
从 runtime.json 文件加载配置(通过统一配置加载器)
|
||||
DEPRECATED: Legacy runtime.json loading
|
||||
|
||||
使用 overrides.py 的统一配置加载器,按优先级加载:
|
||||
1. 数据库配置(如果 dbrun.json 中有 config_id/group_id)
|
||||
2. 环境变量配置
|
||||
3. runtime.json 默认配置
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 运行时配置字典
|
||||
Dict[str, Any]: Empty configuration (legacy support only)
|
||||
"""
|
||||
try:
|
||||
# 使用 overrides.py 的统一配置加载器
|
||||
from app.core.memory.utils.config.overrides import load_unified_config
|
||||
|
||||
runtime_cfg = load_unified_config(PROJECT_ROOT)
|
||||
return runtime_cfg
|
||||
except Exception as e:
|
||||
# Fallback: 直接读取 runtime.json
|
||||
runtime_config_path = os.path.join(PROJECT_ROOT, "runtime.json")
|
||||
try:
|
||||
with open(runtime_config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e2:
|
||||
pass # print(f"[definitions] ❌ 无法加载 runtime.json: {e2},使用空配置")
|
||||
return {"selections": {}}
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return {"selections": {}}
|
||||
|
||||
|
||||
# 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器
|
||||
@@ -104,207 +99,116 @@ def _load_from_runtime_json() -> Dict[str, Any]:
|
||||
# 保留此函数仅为向后兼容
|
||||
def _load_from_database() -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
从数据库加载配置(基于 dbrun.json 中的 config_id)
|
||||
DEPRECATED: Legacy database configuration loading
|
||||
|
||||
注意:此函数已被统一配置加载器替代,现在直接调用 _load_from_runtime_json
|
||||
即可获得包含数据库配置的完整配置。
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 配置字典
|
||||
Optional[Dict[str, Any]]: None (deprecated functionality)
|
||||
"""
|
||||
try:
|
||||
# 直接使用统一配置加载器
|
||||
return _load_from_runtime_json()
|
||||
except Exception:
|
||||
return None
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)
|
||||
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
|
||||
"""
|
||||
将运行时配置暴露为全局常量供项目使用
|
||||
|
||||
这是路径 A(runtime.json)和路径 B(数据库)的汇合点,
|
||||
无论配置来自哪里,都通过这个函数统一暴露为常量。
|
||||
DEPRECATED: 将运行时配置暴露为全局常量供项目使用
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
Use the new MemoryConfig system instead:
|
||||
- app.core.memory_config.config.MemoryConfig for configuration objects
|
||||
- Pass configuration objects as parameters instead of using global variables
|
||||
|
||||
Args:
|
||||
runtime_cfg: 运行时配置字典
|
||||
"""
|
||||
global RUNTIME_CONFIG, SELECTIONS, LOGGING_CONFIG
|
||||
global LANGFUSE_ENABLED, AGENTA_ENABLED, PROMPT_LOG_LEVEL_NAME
|
||||
global SELECTED_LLM_NAME, SELECTED_EMBEDDING_NAME, SELECTED_CHUNKER_STRATEGY
|
||||
global SELECTED_GROUP_ID, SELECTED_USER_ID, SELECTED_APPLY_ID, SELECTED_TEST_DATA_INDICES
|
||||
global SELECTED_LLM_AGENT_NAME, SELECTED_LLM_VERIFY_NAME, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
|
||||
global SELECTED_LLM_ID, SELECTED_EMBEDDING_ID, SELECTED_RERANK_ID
|
||||
global REFLEXION_CONFIG, REFLEXION_ENABLED, REFLEXION_ITERATION_PERIOD, REFLEXION_RANGE, REFLEXION_BASELINE
|
||||
|
||||
RUNTIME_CONFIG = runtime_cfg
|
||||
|
||||
# 可观测性配置
|
||||
LANGFUSE_ENABLED = RUNTIME_CONFIG.get("langfuse", {}).get("enabled", False)
|
||||
AGENTA_ENABLED = RUNTIME_CONFIG.get("agenta", {}).get("enabled", False)
|
||||
|
||||
# 日志配置
|
||||
LOGGING_CONFIG = RUNTIME_CONFIG.get("logging", {})
|
||||
PROMPT_LOG_LEVEL_NAME = LOGGING_CONFIG.get("prompt_level", DEFAULT_VALUES["prompt_level"])
|
||||
|
||||
# 选择配置
|
||||
SELECTIONS = RUNTIME_CONFIG.get("selections", {})
|
||||
|
||||
# 基础模型选择
|
||||
SELECTED_LLM_NAME = SELECTIONS.get("llm_name", DEFAULT_VALUES["llm_name"])
|
||||
SELECTED_EMBEDDING_NAME = SELECTIONS.get("embedding_name", DEFAULT_VALUES["embedding_name"])
|
||||
SELECTED_CHUNKER_STRATEGY = SELECTIONS.get("chunker_strategy", DEFAULT_VALUES["chunker_strategy"])
|
||||
|
||||
# 分组和用户配置
|
||||
SELECTED_GROUP_ID = SELECTIONS.get("group_id", DEFAULT_VALUES["group_id"])
|
||||
SELECTED_USER_ID = SELECTIONS.get("user_id", DEFAULT_VALUES["user_id"])
|
||||
SELECTED_APPLY_ID = SELECTIONS.get("apply_id", DEFAULT_VALUES["apply_id"])
|
||||
SELECTED_TEST_DATA_INDICES = SELECTIONS.get("test_data_indices", None)
|
||||
|
||||
# 专用 LLM 配置
|
||||
SELECTED_LLM_AGENT_NAME = SELECTIONS.get("llm_agent_name", DEFAULT_VALUES["llm_agent_name"])
|
||||
SELECTED_LLM_VERIFY_NAME = SELECTIONS.get("llm_verify_name", DEFAULT_VALUES["llm_verify_name"])
|
||||
SELECTED_LLM_PICTURE_NAME = SELECTIONS.get("llm_image_recognition", DEFAULT_VALUES["llm_image_recognition"])
|
||||
SELECTED_LLM_VOICE_NAME = SELECTIONS.get("llm_voice_recognition", DEFAULT_VALUES["llm_voice_recognition"])
|
||||
|
||||
# 模型 ID 配置
|
||||
SELECTED_LLM_ID = SELECTIONS.get("llm_id", None)
|
||||
SELECTED_EMBEDDING_ID = SELECTIONS.get("embedding_id", None)
|
||||
SELECTED_RERANK_ID = SELECTIONS.get("rerank_id", None)
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
# 反思配置
|
||||
REFLEXION_CONFIG = RUNTIME_CONFIG.get("reflexion", {})
|
||||
REFLEXION_ENABLED = REFLEXION_CONFIG.get("enabled", False)
|
||||
REFLEXION_ITERATION_PERIOD = REFLEXION_CONFIG.get("iteration_period", DEFAULT_VALUES["reflexion_iteration_period"])
|
||||
REFLEXION_RANGE = REFLEXION_CONFIG.get("reflexion_range", DEFAULT_VALUES["reflexion_range"])
|
||||
REFLEXION_BASELINE = REFLEXION_CONFIG.get("baseline", DEFAULT_VALUES["reflexion_baseline"])
|
||||
# Keep minimal global state for backward compatibility only
|
||||
# These will be removed in a future version
|
||||
global RUNTIME_CONFIG, SELECTIONS
|
||||
|
||||
RUNTIME_CONFIG = runtime_cfg
|
||||
SELECTIONS = RUNTIME_CONFIG.get("selections", {})
|
||||
|
||||
# All other global variables have been removed
|
||||
# Use MemoryConfig objects instead
|
||||
|
||||
|
||||
# 初始化:使用统一配置加载器
|
||||
def _initialize_configuration() -> None:
|
||||
"""
|
||||
初始化配置:使用统一配置加载器
|
||||
DEPRECATED: Legacy configuration initialization
|
||||
|
||||
配置加载优先级(由 overrides.py 统一处理):
|
||||
1. 数据库配置(如果 dbrun.json 中有 config_id/group_id)
|
||||
2. 环境变量配置(.env)
|
||||
3. runtime.json 默认配置
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
"""
|
||||
try:
|
||||
|
||||
# 使用统一配置加载器(已包含所有优先级处理)
|
||||
runtime_config = _load_from_runtime_json()
|
||||
|
||||
# 暴露为全局常量
|
||||
_expose_runtime_constants(runtime_config)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
pass # print(f"[definitions] × 配置初始化失败: {e}")
|
||||
# 使用空配置
|
||||
_expose_runtime_constants({"selections": {}})
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# Initialize with empty configuration for backward compatibility
|
||||
_expose_runtime_constants({"selections": {}})
|
||||
|
||||
|
||||
# 模块加载时自动初始化配置
|
||||
_initialize_configuration()
|
||||
|
||||
# DEPRECATED: Global variables removed
|
||||
# These variables have been eliminated in favor of dependency injection
|
||||
# Use MemoryConfig objects instead of accessing global variables
|
||||
|
||||
|
||||
# 公共 API:动态重新加载配置
|
||||
def reload_configuration_from_database(config_id: int | str, force_reload: bool = False) -> bool:
|
||||
def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
|
||||
"""
|
||||
动态重新加载配置(从数据库)- 使用统一配置加载器
|
||||
用于运行时切换配置,例如前端传入新的 config_id 时调用。
|
||||
|
||||
注意:此函数仅在内存中覆写配置,不会修改 runtime.json 文件。
|
||||
DEPRECATED: Legacy configuration reloading
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
For new code, use:
|
||||
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID(整数或字符串,会自动转换)
|
||||
force_reload: 保留参数以保持向后兼容(已移除缓存逻辑)
|
||||
config_id: Configuration ID (deprecated)
|
||||
force_reload: Force reload flag (deprecated)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功重新加载配置
|
||||
bool: Always returns False (deprecated functionality)
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 导入审计日志记录器
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
warnings.warn(
|
||||
"reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
with _config_lock:
|
||||
try:
|
||||
from app.core.memory.utils.config.overrides import load_unified_config
|
||||
except Exception as e:
|
||||
logger.error(f"[definitions] 导入统一配置加载器失败: {e}")
|
||||
|
||||
# 记录配置加载失败
|
||||
if audit_logger:
|
||||
audit_logger.log_config_load(
|
||||
config_id=config_id,
|
||||
success=False,
|
||||
details={"error": f"Import failed: {str(e)}"}
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info(f"[definitions] 开始重新加载配置,config_id={config_id}")
|
||||
|
||||
# 使用统一配置加载器(指定 config_id)
|
||||
updated_cfg = load_unified_config(PROJECT_ROOT, config_id=config_id)
|
||||
|
||||
# 检查是否成功加载
|
||||
if not updated_cfg or not updated_cfg.get('selections'):
|
||||
logger.error(f"[definitions] 配置加载失败:数据库中未找到 config_id={config_id} 的配置")
|
||||
|
||||
# 记录配置加载失败
|
||||
if audit_logger:
|
||||
audit_logger.log_config_load(
|
||||
config_id=config_id,
|
||||
success=False,
|
||||
details={"reason": "config not found in database"}
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
# 重新暴露常量
|
||||
_expose_runtime_constants(updated_cfg)
|
||||
|
||||
logger.info("[definitions] 配置重新加载成功,已暴露常量")
|
||||
logger.debug(f"[definitions] 配置详情: LLM_ID={updated_cfg.get('selections', {}).get('llm_id')}, "
|
||||
f"EMBEDDING_ID={updated_cfg.get('selections', {}).get('embedding_id')}")
|
||||
|
||||
# 记录成功的配置加载
|
||||
if audit_logger:
|
||||
selections = updated_cfg.get('selections', {})
|
||||
audit_logger.log_config_load(
|
||||
config_id=config_id,
|
||||
user_id=selections.get('user_id', None),
|
||||
group_id=selections.get('group_id', None),
|
||||
success=True,
|
||||
details={
|
||||
"llm_id": selections.get('llm_id'),
|
||||
"embedding_id": selections.get('embedding_id'),
|
||||
"chunker_strategy": selections.get('chunker_strategy')
|
||||
}
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[definitions] 重新加载配置时发生异常: {e}", exc_info=True)
|
||||
|
||||
# 记录配置加载异常
|
||||
if audit_logger:
|
||||
audit_logger.log_config_load(
|
||||
config_id=config_id,
|
||||
success=False,
|
||||
details={"error": str(e)}
|
||||
)
|
||||
|
||||
return False
|
||||
logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
|
||||
"Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -312,49 +216,54 @@ def reload_configuration_from_database(config_id: int | str, force_reload: bool
|
||||
|
||||
def get_current_config_id() -> Optional[str]:
|
||||
"""
|
||||
获取当前使用的 config_id
|
||||
DEPRECATED: Legacy config ID retrieval
|
||||
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
Returns:
|
||||
Optional[str]: 当前的 config_id,如果未设置则返回 None
|
||||
Optional[str]: None (deprecated functionality)
|
||||
"""
|
||||
return SELECTIONS.get("config_id", None)
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def ensure_fresh_config(config_id: Optional[int | str] = None) -> bool:
|
||||
def ensure_fresh_config(config_id = None) -> bool:
|
||||
"""
|
||||
确保使用最新的配置(每次写入操作前调用)
|
||||
DEPRECATED: Legacy configuration freshness check
|
||||
|
||||
如果提供了 config_id,则加载该配置;
|
||||
否则从 dbrun.json 读取并加载最新配置。
|
||||
⚠️ This function is deprecated and will be removed in a future version.
|
||||
Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
For new code, use:
|
||||
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
Args:
|
||||
config_id: 可选的配置ID(整数或字符串,会自动转换)
|
||||
config_id: Configuration ID (deprecated)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功加载配置
|
||||
bool: Always returns False (deprecated functionality)
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
with _config_lock:
|
||||
try:
|
||||
if config_id:
|
||||
# 使用指定的 config_id
|
||||
logger.debug(f"[definitions] 加载指定配置,config_id={config_id}")
|
||||
return reload_configuration_from_database(config_id)
|
||||
else:
|
||||
# 从数据库重新加载配置
|
||||
logger.debug("[definitions] 从数据库重新加载最新配置")
|
||||
memory_config = _load_from_database()
|
||||
|
||||
if not memory_config or not memory_config.get('selections'):
|
||||
logger.warning("[definitions] 未能从数据库加载配置,使用当前配置")
|
||||
return False
|
||||
|
||||
_expose_memory_constants(memory_config)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[definitions] 加载配置失败: {e}", exc_info=True)
|
||||
return False
|
||||
warnings.warn(
|
||||
"ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
|
||||
"Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -1,611 +0,0 @@
|
||||
"""
|
||||
运行时配置覆写工具 - 统一配置加载器
|
||||
|
||||
本模块作为统一的配置加载器,负责从多个来源加载配置并按优先级覆写。
|
||||
|
||||
配置来源优先级(从高到低):
|
||||
1. 数据库配置(PostgreSQL data_config 表)
|
||||
2. 环境变量配置(.env 文件)
|
||||
3. 默认配置(runtime.json 文件)
|
||||
|
||||
支持的配置加载方式:
|
||||
- 基于 config_id 的配置加载(从 dbrun.json 读取或前端传入)
|
||||
- 基于 group_id 的配置加载(从 dbrun.json 读取)
|
||||
- 环境变量覆写(支持 INTERNAL/EXTERNAL 网络模式)
|
||||
|
||||
主要功能:
|
||||
- 从 PostgreSQL 数据库读取配置
|
||||
- 从环境变量读取配置
|
||||
- 从 runtime.json 读取默认配置
|
||||
- 按优先级覆写配置项(仅在内存中,不修改文件)
|
||||
- 支持多种配置字段:selections、statement_extraction、deduplication、forgetting_engine、pruning、reflexion
|
||||
|
||||
使用场景:
|
||||
- 应用启动时自动加载配置
|
||||
- 前端切换配置时动态重新加载
|
||||
- 多租户场景下的配置隔离
|
||||
- 内外网环境自动切换
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import socket
|
||||
from typing import Optional, Dict, Any, Literal
|
||||
|
||||
NetworkMode = Literal['internal', 'external']
|
||||
|
||||
|
||||
def _set_if_present(target: Dict[str, Any], target_key: str, src: Dict[str, Any], src_key: str, caster):
|
||||
"""安全地设置目标字典的值(如果源字典中存在且不为 None)
|
||||
|
||||
Args:
|
||||
target: 目标字典
|
||||
target_key: 目标字典的键
|
||||
src: 源字典
|
||||
src_key: 源字典的键
|
||||
caster: 类型转换函数
|
||||
"""
|
||||
try:
|
||||
if src_key in src and src.get(src_key) is not None:
|
||||
try:
|
||||
target[target_key] = caster(src.get(src_key))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _to_bool(val: Any) -> bool:
|
||||
"""将各种类型的值转换为布尔值
|
||||
|
||||
支持的输入:
|
||||
- bool: 直接返回
|
||||
- int/float: 非零为 True
|
||||
- str: "true", "1", "on", "yes" 为 True;"false", "0", "off", "no" 为 False
|
||||
|
||||
Args:
|
||||
val: 要转换的值
|
||||
|
||||
Returns:
|
||||
bool: 转换后的布尔值
|
||||
"""
|
||||
try:
|
||||
if isinstance(val, bool):
|
||||
return val
|
||||
if isinstance(val, (int, float)):
|
||||
return bool(val)
|
||||
if isinstance(val, str):
|
||||
m = val.strip().lower()
|
||||
if m in {"true", "1", "on", "yes"}:
|
||||
return True
|
||||
if m in {"false", "0", "off", "no"}:
|
||||
return False
|
||||
return bool(val)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _make_pgsql_conn() -> Optional[object]:
    """Open a PostgreSQL connection configured from environment variables.

    Environment variables read:
    - DB_HOST: server host (defaults to "localhost")
    - DB_PORT: server port (defaults to 5432 when unset or empty)
    - DB_USER: user name
    - DB_PASSWORD: password
    - DB_NAME: database name

    Returns:
        Optional[object]: A psycopg2 connection with autocommit enabled, or
        ``None`` if psycopg2 cannot be imported, DB_PORT is not an integer,
        or the connection attempt fails.
    """
    try:
        # Imported lazily so the module still loads when psycopg2 is absent.
        import psycopg2  # type: ignore

        conn = psycopg2.connect(
            host=os.getenv("DB_HOST", "localhost"),
            port=int(os.getenv("DB_PORT") or 5432),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            dbname=os.getenv("DB_NAME"),
        )
        conn.autocommit = True
        return conn
    except Exception:
        # Best-effort: callers treat None as "database configuration unavailable".
        return None
|
||||
|
||||
|
||||
def _fetch_db_config_by_group_id(group_id: str) -> Optional[Dict[str, Any]]:
    """Fetch the most recently updated ``data_config`` row for a group.

    Args:
        group_id: Group identifier bound into the ``WHERE group_id = %s`` filter.

    Returns:
        Optional[Dict[str, Any]]: The row produced by a ``RealDictCursor``, or
        ``None`` when no connection is available, the query fails, or no row
        matches.
    """
    conn = _make_pgsql_conn()
    if conn is None:
        return None

    # Bound up front so the cleanup below never touches an undefined name
    # when cursor creation itself fails.
    cur = None
    try:
        from psycopg2.extras import RealDictCursor  # type: ignore

        cur = conn.cursor(cursor_factory=RealDictCursor)

        try:
            cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
        except Exception:
            # The session time zone is a nicety; the lookup proceeds without it.
            pass

        sql = (
            "SELECT group_id, user_id, apply_id, chunker_strategy, "
            " enable_llm_dedup_blockwise, enable_llm_disambiguation "
            "FROM data_config WHERE group_id = %s ORDER BY updated_at DESC LIMIT 1"
        )
        cur.execute(sql, (group_id,))
        row = cur.fetchone()
        return row if row else None
    except Exception:
        # Best-effort lookup: any database error means "no override available".
        return None
    finally:
        # Release the cursor first, then the connection; a close failure
        # must not mask the result.
        for resource in (cur, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                pass
|
||||
|
||||
|
||||
def _fetch_db_config_by_config_id(config_id: int | str) -> Optional[Dict[str, Any]]:
    """Fetch the ``data_config`` row with the given ``config_id``.

    Args:
        config_id: Configuration identifier (integer or numeric string). It is
            converted with ``int()`` before being bound into the query.

    Returns:
        Optional[Dict[str, Any]]: The row produced by a ``RealDictCursor``, or
        ``None`` when ``config_id`` is not an integer, no connection is
        available, the query fails, or no row matches.
    """
    # Validate before connecting so a malformed id never costs a connection.
    try:
        config_id_int = int(config_id)
    except (ValueError, TypeError, OverflowError):
        return None

    conn = _make_pgsql_conn()
    if conn is None:
        return None

    # Bound up front so the cleanup below never touches an undefined name
    # when cursor creation itself fails.
    cur = None
    try:
        from psycopg2.extras import RealDictCursor  # type: ignore

        cur = conn.cursor(cursor_factory=RealDictCursor)

        try:
            cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
        except Exception:
            # The session time zone is a nicety; the lookup proceeds without it.
            pass

        sql = (
            "SELECT config_id, group_id, user_id, apply_id, chunker_strategy, "
            " enable_llm_dedup_blockwise, enable_llm_disambiguation, "
            " deep_retrieval, t_type_strict, t_name_strict, t_overall, state, "
            " statement_granularity, include_dialogue_context, max_context, "
            " \"offset\" AS offset, lambda_time, lambda_mem, "
            " pruning_enabled, pruning_scene, pruning_threshold, "
            " llm_id, embedding_id "
            "FROM data_config WHERE config_id = %s LIMIT 1"
        )
        cur.execute(sql, (config_id_int,))
        row = cur.fetchone()
        return row if row else None
    except Exception:
        # Best-effort lookup: any database error means "no override available".
        return None
    finally:
        # Release the cursor first, then the connection; a close failure
        # must not mask the result.
        for resource in (cur, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                pass
|
||||
|
||||
|
||||
def _load_dbrun_group_id(project_root: str) -> Optional[str]:
|
||||
"""从 dbrun.json 读取 group_id
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录路径
|
||||
|
||||
Returns:
|
||||
Optional[str]: group_id,未找到时返回 None
|
||||
"""
|
||||
try:
|
||||
path = os.path.join(project_root, "dbrun.json")
|
||||
if not os.path.isfile(path):
|
||||
return None
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if isinstance(data, dict):
|
||||
if "group_id" in data:
|
||||
return str(data.get("group_id"))
|
||||
sel = data.get("selections", {})
|
||||
if isinstance(sel, dict) and "group_id" in sel:
|
||||
return str(sel.get("group_id"))
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _load_dbrun_config_id(project_root: str) -> Optional[str]:
|
||||
"""从 dbrun.json 读取 config_id
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录路径
|
||||
|
||||
Returns:
|
||||
Optional[str]: config_id,未找到时返回 None
|
||||
"""
|
||||
try:
|
||||
path = os.path.join(project_root, "dbrun.json")
|
||||
if not os.path.isfile(path):
|
||||
return None
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if isinstance(data, dict):
|
||||
if "config_id" in data:
|
||||
return str(data.get("config_id"))
|
||||
sel = data.get("selections", {})
|
||||
if isinstance(sel, dict) and "config_id" in sel:
|
||||
return str(sel.get("config_id"))
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _apply_overrides_from_db_row(
    runtime_cfg: Dict[str, Any],
    db_row: Optional[Dict[str, Any]],
    identifier: str,
    identifier_type: str = "config_id"
) -> Dict[str, Any]:
    """Overlay values from a ``data_config`` row onto the runtime configuration.

    ``runtime_cfg`` is modified in place and also returned. The identifier is
    always recorded under ``selections[identifier_type]``, even when
    ``db_row`` is empty. Columns that are absent or ``None`` leave the
    existing value untouched.

    Args:
        runtime_cfg: Runtime configuration dictionary to update.
        db_row: Row fetched from the database, or ``None``/empty.
        identifier: Value of the group_id or config_id that was looked up.
        identifier_type: Either "group_id" or "config_id".

    Returns:
        Dict[str, Any]: The same ``runtime_cfg`` object after the overlay.
    """

    def _clamped_threshold(raw: Any) -> float:
        # The pruning threshold is restricted to the range [0.0, 0.9].
        return max(0.0, min(0.9, float(raw)))

    try:
        selections = runtime_cfg.setdefault("selections", {})
        selections[identifier_type] = identifier

        if not db_row:
            return runtime_cfg

        # selections: every value is stored as a string. str() on a UUID
        # object yields the hyphenated form, so llm_id and embedding_id need
        # no special handling.
        for key in ("group_id", "user_id", "apply_id", "chunker_strategy", "state",
                    "t_type_strict", "t_name_strict", "t_overall",
                    "statement_granularity", "include_dialogue_context",
                    "llm_id", "embedding_id"):
            _set_if_present(selections, key, db_row, key, str)

        # statement_extraction
        stmt = runtime_cfg.setdefault("statement_extraction", {})
        _set_if_present(stmt, "statement_granularity", db_row, "statement_granularity", int)
        _set_if_present(stmt, "include_dialogue_context", db_row, "include_dialogue_context", _to_bool)
        _set_if_present(stmt, "max_dialogue_context_chars", db_row, "max_context", int)

        # deduplication
        dedup = runtime_cfg.setdefault("deduplication", {})
        for key in ("enable_llm_dedup_blockwise", "enable_llm_disambiguation", "deep_retrieval"):
            _set_if_present(dedup, key, db_row, key, _to_bool)

        # forgetting_engine
        forgetting = runtime_cfg.setdefault("forgetting_engine", {})
        for key in ("offset", "lambda_time", "lambda_mem"):
            _set_if_present(forgetting, key, db_row, key, float)

        # pruning
        pruning = runtime_cfg.setdefault("pruning", {})
        _set_if_present(pruning, "enabled", db_row, "pruning_enabled", _to_bool)
        _set_if_present(pruning, "scene", db_row, "pruning_scene", str)
        _set_if_present(pruning, "threshold", db_row, "pruning_threshold", _clamped_threshold)

        return runtime_cfg
    except Exception:
        # Best-effort overlay: on an unexpected failure return the
        # configuration with whatever was applied so far.
        return runtime_cfg
|
||||
|
||||
|
||||
def apply_runtime_overrides_by_group(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay database settings selected by the ``group_id`` in dbrun.json.

    Steps:
    1. Read ``group_id`` from dbrun.json under ``project_root``.
    2. Look up the database configuration for that group.
    3. Apply it to ``runtime_cfg`` (in memory only).

    Args:
        project_root: Project root directory.
        runtime_cfg: Runtime configuration dictionary; updated in place.

    Returns:
        Dict[str, Any]: The runtime configuration after the overlay, or the
        configuration as passed in when no group id is configured or an error
        occurs.
    """
    try:
        gid = _load_dbrun_group_id(project_root)
        if gid:
            row = _fetch_db_config_by_group_id(gid)
            if row:
                return _apply_overrides_from_db_row(runtime_cfg, row, gid, "group_id")
            # No stored configuration: still record which group was selected.
            chosen = runtime_cfg.setdefault("selections", {})
            chosen["group_id"] = gid
        return runtime_cfg
    except Exception:
        return runtime_cfg
|
||||
|
||||
|
||||
def apply_runtime_overrides_by_config(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay database settings selected by the ``config_id`` in dbrun.json.

    Steps:
    1. Read ``config_id`` from dbrun.json under ``project_root``.
    2. Look up the database configuration for that id.
    3. Apply it to ``runtime_cfg`` (in memory only).

    Args:
        project_root: Project root directory.
        runtime_cfg: Runtime configuration dictionary; updated in place.

    Returns:
        Dict[str, Any]: The runtime configuration after the overlay, or the
        configuration as passed in when no config id is configured or an
        error occurs.
    """
    try:
        cid = _load_dbrun_config_id(project_root)
        if cid:
            fetched = _fetch_db_config_by_config_id(cid)
            return _apply_overrides_from_db_row(runtime_cfg, fetched, cid, "config_id")
        return runtime_cfg
    except Exception:
        return runtime_cfg
|
||||
|
||||
|
||||
def apply_runtime_overrides_with_config_id(
    project_root: str,
    runtime_cfg: Dict[str, Any],
    config_id: str
) -> tuple[Dict[str, Any], bool]:
    """Overlay database settings for an explicitly supplied ``config_id``.

    Unlike the dbrun.json-based variants, the identifier comes from the
    caller, e.g. when the front end switches configuration at runtime.

    Args:
        project_root: Project root directory (not read by this function).
        runtime_cfg: Runtime configuration dictionary; updated in place.
        config_id: Configuration identifier; surrounding whitespace is ignored.

    Returns:
        tuple[Dict[str, Any], bool]: The runtime configuration and a flag that
        is ``True`` only when a database row was found and applied.
    """
    try:
        cid = str(config_id).strip()
        row = _fetch_db_config_by_config_id(cid) if cid else None
        if row is None:
            return runtime_cfg, False
        return _apply_overrides_from_db_row(runtime_cfg, row, cid, "config_id"), True
    except Exception:
        return runtime_cfg, False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 以下函数已注释:不再需要网络模式自动检测功能
|
||||
# ============================================================================
|
||||
|
||||
# def get_server_ip() -> str:
|
||||
# """
|
||||
# 获取当前服务器的IP地址
|
||||
#
|
||||
# Returns:
|
||||
# 服务器IP地址字符串
|
||||
# """
|
||||
# try:
|
||||
# # 方式1:从环境变量获取(优先)
|
||||
# server_ip = os.getenv('SERVER_IP')
|
||||
# if server_ip and server_ip not in ['127.0.0.1', 'localhost', '0.0.0.0']:
|
||||
# return server_ip
|
||||
#
|
||||
# # 方式2:通过socket获取
|
||||
# hostname = socket.gethostname()
|
||||
# ip_address = socket.gethostbyname(hostname)
|
||||
#
|
||||
# # 如果是本地回环地址,尝试获取真实IP
|
||||
# if ip_address.startswith('127.'):
|
||||
# # 尝试连接外部地址来获取本机IP
|
||||
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
# try:
|
||||
# s.connect(('8.8.8.8', 80))
|
||||
# ip_address = s.getsockname()[0]
|
||||
# finally:
|
||||
# s.close()
|
||||
#
|
||||
# return ip_address
|
||||
# except Exception as e:
|
||||
# print(f"[overrides] 获取服务器IP失败: {e},使用默认值 127.0.0.1")
|
||||
# return '127.0.0.1'
|
||||
|
||||
|
||||
# def auto_detect_network_mode() -> NetworkMode:
|
||||
# """
|
||||
# 自动检测网络模式(基于服务器IP)
|
||||
#
|
||||
# 规则:
|
||||
# - 如果服务器IP在内网IP列表中 → internal(内网)
|
||||
# - 其他IP → external(外网)
|
||||
#
|
||||
# 可以通过环境变量 INTERNAL_SERVER_IPS 自定义内网IP列表(逗号分隔)
|
||||
#
|
||||
# Returns:
|
||||
# 'internal' 或 'external'
|
||||
# """
|
||||
# server_ip = get_server_ip()
|
||||
#
|
||||
# # 从环境变量获取内网IP列表(支持多个IP,逗号分隔)
|
||||
# internal_ips_str = os.getenv('INTERNAL_SERVER_IPS', '119.45.181.55')
|
||||
# internal_ips = [ip.strip() for ip in internal_ips_str.split(',')]
|
||||
#
|
||||
# # 判断当前IP是否在内网IP列表中
|
||||
# if server_ip in internal_ips:
|
||||
# print(f"[overrides] 自动检测:服务器IP {server_ip} 属于内网,使用 INTERNAL 配置")
|
||||
# return 'internal'
|
||||
# else:
|
||||
# print(f"[overrides] 自动检测:服务器IP {server_ip} 属于外网,使用 EXTERNAL 配置")
|
||||
# return 'external'
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 环境变量覆写功能已废弃 - 不再使用
|
||||
# ============================================================================
|
||||
# def _apply_env_var_overrides(runtime_cfg: Dict[str, Any], network_mode: NetworkMode = None, force_override: bool = False) -> Dict[str, Any]:
|
||||
# """
|
||||
# 从环境变量覆写配置(已废弃)
|
||||
# """
|
||||
# return runtime_cfg
|
||||
|
||||
|
||||
def load_unified_config(
    project_root: str,
    config_id: Optional[int | str] = None,
    group_id: Optional[str] = None,
    network_mode: Optional[NetworkMode] = None,
    env_override_models: bool = True
) -> Dict[str, Any]:
    """Unified configuration loader: runtime.json defaults plus database overrides.

    ``runtime.json`` under ``project_root`` provides the base configuration
    (``{"selections": {}}`` if it is missing or invalid). One database lookup
    is then chosen, first match wins:

    1. the ``config_id`` argument
    2. the ``group_id`` argument
    3. ``config_id`` from dbrun.json
    4. ``group_id`` from dbrun.json (only consulted when dbrun.json has no
       ``config_id``)

    If the chosen lookup returns a row, its values are overlaid on the base
    configuration; otherwise the base configuration is returned unchanged.

    Args:
        project_root: Project root directory.
        config_id: Configuration id (integer or string), optional.
        group_id: Group id, optional.
        network_mode: Deprecated; kept only for backward compatibility.
        env_override_models: Deprecated; kept only for backward compatibility.

    Returns:
        Dict[str, Any]: The final runtime configuration, or
        ``{"selections": {}}`` if loading fails unexpectedly.
    """
    try:
        # Step 1: runtime.json is the base configuration.
        runtime_config_path = os.path.join(project_root, "runtime.json")
        try:
            with open(runtime_config_path, "r", encoding="utf-8") as f:
                runtime_cfg = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            runtime_cfg = {"selections": {}}

        # Step 2: decide which identifier drives the database lookup.
        if config_id:
            lookup = (config_id, "config_id", _fetch_db_config_by_config_id)
        elif group_id:
            lookup = (group_id, "group_id", _fetch_db_config_by_group_id)
        else:
            dbrun_config_id = _load_dbrun_config_id(project_root)
            if dbrun_config_id:
                lookup = (dbrun_config_id, "config_id", _fetch_db_config_by_config_id)
            else:
                lookup = (_load_dbrun_group_id(project_root), "group_id", _fetch_db_config_by_group_id)

        # Step 3: overlay the database row, if there is one.
        identifier, identifier_type, fetch_row = lookup
        if identifier:
            db_row = fetch_row(identifier)
            if db_row:
                runtime_cfg = _apply_overrides_from_db_row(
                    runtime_cfg, db_row, identifier, identifier_type
                )

        return runtime_cfg
    except Exception:
        # Best-effort loader: callers always receive a usable dictionary.
        return {"selections": {}}
|
||||
|
||||
|
||||
# 向后兼容的别名
|
||||
apply_runtime_overrides = apply_runtime_overrides_by_config
|
||||
11
api/app/core/memory/utils/embedder/__init__.py
Normal file
11
api/app/core/memory/utils/embedder/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Embedder utilities module."""
|
||||
|
||||
from app.core.memory.utils.embedder.embedder_utils import (
|
||||
get_embedder_client,
|
||||
get_embedder_client_from_config,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_embedder_client",
|
||||
"get_embedder_client_from_config",
|
||||
]
|
||||
81
api/app/core/memory/utils/embedder/embedder_utils.py
Normal file
81
api/app/core/memory/utils/embedder/embedder_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Embedder Client Utilities
|
||||
|
||||
This module provides centralized functions for creating embedder clients.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
def get_embedder_client_from_config(memory_config: "MemoryConfig") -> OpenAIEmbedderClient:
    """
    Build an embedder client for the embedding model referenced by a MemoryConfig.

    **PREFERRED METHOD**: call this from production code whenever a MemoryConfig
    object is already available, instead of passing a raw model ID around.

    Args:
        memory_config: MemoryConfig object; its ``embedding_model_id`` selects
            the model and its ``config_id`` is used in the error message.

    Returns:
        OpenAIEmbedderClient: The client produced by ``get_embedder_client``.

    Raises:
        ValueError: If ``embedding_model_id`` is empty/unset, or if
            ``get_embedder_client`` fails to resolve or initialize the model.

    Example:
        >>> embedder_client = get_embedder_client_from_config(memory_config)
    """
    if memory_config.embedding_model_id:
        # The ID is stringified before delegation (the delegate takes a str).
        return get_embedder_client(str(memory_config.embedding_model_id))

    raise ValueError(
        f"Configuration {memory_config.config_id} has no embedding model configured"
    )
|
||||
|
||||
|
||||
def get_embedder_client(embedding_id: str) -> OpenAIEmbedderClient:
    """
    Get embedder client by model ID.

    **LEGACY/TEST METHOD**: Use this function only for:
    - Test/evaluation code where you have a model ID directly
    - Legacy code that hasn't been migrated to MemoryConfig yet

    For production code with MemoryConfig, use get_embedder_client_from_config() instead.

    Args:
        embedding_id: Embedding model ID (required)

    Returns:
        OpenAIEmbedderClient: Initialized embedder client

    Raises:
        ValueError: If embedding_id is not provided, if the configuration lookup
            fails, or if client initialization fails. The underlying exception
            is attached as ``__cause__`` in the latter two cases.

    Example:
        >>> # For tests/evaluations only
        >>> embedder_client = get_embedder_client("model-uuid-string")
    """
    if not embedding_id:
        raise ValueError("Embedding ID is required but was not provided")

    # Step 1: resolve the stored configuration for this model ID.
    try:
        embedder_config_dict = get_embedder_config(embedding_id)
    except Exception as e:
        raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e

    # Step 2: turn the configuration into a client.
    try:
        embedder_config = RedBearModelConfig(**embedder_config_dict)
        return OpenAIEmbedderClient(embedder_config)
    except Exception as e:
        # The lookup result may not be a mapping (that is one way the unpacking
        # above fails, e.g. when it is None). Read the model name defensively so
        # this handler always raises the documented ValueError instead of an
        # AttributeError of its own.
        getter = getattr(embedder_config_dict, "get", None)
        model_name = getter('model_name', 'unknown') if callable(getter) else 'unknown'
        raise ValueError(
            f"Failed to initialize embedder client for model '{model_name}': {str(e)}"
        ) from e
|
||||
@@ -1,67 +1,107 @@
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
async def handle_response(response: BaseModel) -> dict:
    """Serialize a Pydantic model instance into a plain dict.

    Args:
        response: The model instance to dump. ``model_dump`` is invoked on the
            value itself, so this must be an instance rather than a model class.

    Returns:
        dict: The result of ``response.model_dump()``.
    """
    return response.model_dump()
|
||||
|
||||
|
||||
def get_llm_client(llm_id: str | None = None):
|
||||
llm_id = llm_id or config_defs.SELECTED_LLM_ID
|
||||
def get_llm_client_from_config(memory_config: "MemoryConfig") -> OpenAIClient:
    """
    Build an LLM client for the LLM model referenced by a MemoryConfig.

    **PREFERRED METHOD**: call this from production code whenever a MemoryConfig
    object is already available, instead of passing a raw model ID around.

    Args:
        memory_config: MemoryConfig object; its ``llm_model_id`` selects the
            model and its ``config_id`` is used in the error message.

    Returns:
        OpenAIClient: The client produced by ``get_llm_client``.

    Raises:
        ValueError: If ``llm_model_id`` is empty/unset, or if ``get_llm_client``
            fails to resolve or initialize the model.

    Example:
        >>> llm_client = get_llm_client_from_config(memory_config)
    """
    if memory_config.llm_model_id:
        # The ID is stringified before delegation (the delegate takes a str).
        return get_llm_client(str(memory_config.llm_model_id))

    raise ValueError(
        f"Configuration {memory_config.config_id} has no LLM model configured"
    )
|
||||
|
||||
# Validate LLM ID exists before attempting to get config
|
||||
|
||||
def get_llm_client(llm_id: str) -> OpenAIClient:
    """
    Get LLM client by model ID.

    **LEGACY/TEST METHOD**: Use this function only for:
    - Test/evaluation code where you have a model ID directly
    - Legacy code that hasn't been migrated to MemoryConfig yet

    For production code with MemoryConfig, use get_llm_client_from_config() instead.

    Args:
        llm_id: LLM model ID (required)

    Returns:
        OpenAIClient: Initialized LLM client

    Raises:
        ValueError: If llm_id is not provided, if the configuration lookup
            fails, or if client initialization fails. The underlying exception
            is attached as ``__cause__`` in the latter two cases.

    Example:
        >>> # For tests/evaluations only
        >>> llm_client = get_llm_client("model-uuid-string")
    """
    if not llm_id:
        raise ValueError("LLM ID is required but was not provided")

    # Step 1: resolve the stored configuration for this model ID.
    try:
        model_config = get_model_config(llm_id)
    except Exception as e:
        # Re-raise with clear error message about invalid LLM ID
        raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e

    # Step 2: build the client from the resolved configuration.
    try:
        return OpenAIClient(
            RedBearModelConfig(
                model_name=model_config.get("model_name"),
                provider=model_config.get("provider"),
                api_key=model_config.get("api_key"),
                base_url=model_config.get("base_url"),
            ),
            type_=model_config.get("type"),
        )
    except Exception as e:
        # The lookup result may not be a mapping (e.g. None), which is itself
        # one way the block above fails. Read the model name defensively so this
        # handler always raises the documented ValueError.
        getter = getattr(model_config, "get", None)
        model_name = getter('model_name', 'unknown') if callable(getter) else 'unknown'
        raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
|
||||
def get_reranker_client(rerank_id: str | None = None):
|
||||
def get_reranker_client(rerank_id: str):
|
||||
"""
|
||||
Get an LLM client configured for reranking.
|
||||
|
||||
Args:
|
||||
rerank_id: Optional reranker model ID. If None, uses SELECTED_RERANK_ID.
|
||||
rerank_id: Reranker model ID (required)
|
||||
|
||||
Returns:
|
||||
OpenAIClient: Initialized client for the reranker model
|
||||
|
||||
Raises:
|
||||
ValueError: If rerank_id is invalid or client initialization fails
|
||||
ValueError: If rerank_id is not provided or client initialization fails
|
||||
"""
|
||||
rerank_id = rerank_id or config_defs.SELECTED_RERANK_ID
|
||||
|
||||
# Validate rerank ID exists before attempting to get config
|
||||
if not rerank_id:
|
||||
raise ValueError("Rerank ID is required but was not provided")
|
||||
|
||||
try:
|
||||
model_config = get_model_config(rerank_id)
|
||||
except Exception as e:
|
||||
# Re-raise with clear error message about invalid rerank ID
|
||||
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
|
||||
@@ -10,28 +10,29 @@
|
||||
从 app.core.memory.src.data_config_api.self_reflexion 迁移而来。
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
REFLEXION_ENABLED = os.getenv("REFLEXION_ENABLED", "false").lower() == "true"
|
||||
REFLEXION_ITERATION_PERIOD = os.getenv("REFLEXION_ITERATION_PERIOD", "3")
|
||||
REFLEXION_RANGE = os.getenv("REFLEXION_RANGE", "retrieval")
|
||||
REFLEXION_BASELINE = os.getenv("REFLEXION_BASELINE", "TIME")
|
||||
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
REFLEXION_ENABLED,
|
||||
REFLEXION_ITERATION_PERIOD,
|
||||
REFLEXION_RANGE,
|
||||
REFLEXION_BASELINE,
|
||||
)
|
||||
from app.db import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
|
||||
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
|
||||
from app.db import get_db
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 并发限制(可通过环境变量覆盖)
|
||||
CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5"))
|
||||
|
||||
@@ -1,6 +1,21 @@
|
||||
"""
|
||||
Validators for file upload system.
|
||||
Validators package for various validation utilities.
|
||||
"""
|
||||
from app.core.validators.file_validator import FileValidator, ValidationResult
|
||||
from app.core.validators.memory_config_validators import (
|
||||
validate_and_resolve_model_id,
|
||||
validate_embedding_model,
|
||||
validate_llm_model,
|
||||
validate_model_exists_and_active,
|
||||
)
|
||||
|
||||
__all__ = ["FileValidator", "ValidationResult"]
|
||||
__all__ = [
|
||||
# File validators
|
||||
"FileValidator",
|
||||
"ValidationResult",
|
||||
# Memory config validators
|
||||
"validate_model_exists_and_active",
|
||||
"validate_and_resolve_model_id",
|
||||
"validate_embedding_model",
|
||||
"validate_llm_model",
|
||||
]
|
||||
|
||||
250
api/app/core/validators/memory_config_validators.py
Normal file
250
api/app/core/validators/memory_config_validators.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Memory Configuration Validators
|
||||
|
||||
This module provides validation functions for memory configuration models.
|
||||
|
||||
Functions:
|
||||
validate_model_exists_and_active: Validate model exists and is active
|
||||
validate_and_resolve_model_id: Validate and resolve model ID with DB lookup
|
||||
validate_embedding_model: Validate embedding model availability
|
||||
validate_llm_model: Validate LLM model availability
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.logging_config import get_config_logger
|
||||
from app.schemas.memory_config_schema import (
|
||||
InvalidConfigError,
|
||||
ModelInactiveError,
|
||||
ModelNotFoundError,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_config_logger()
|
||||
|
||||
|
||||
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
                    config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
    """Parse a model ID given as a string or UUID.

    Args:
        model_id: Raw model ID. ``None`` and blank strings mean "not set".
        model_type: Model kind label; used in error messages and to build the
            ``<model_type>_model_id`` field name.
        config_id: Optional configuration ID, forwarded as error context.
        workspace_id: Optional workspace ID, forwarded as error context.

    Returns:
        The parsed UUID (the same object when a UUID was passed in), or None
        when ``model_id`` is None or a blank/whitespace-only string.

    Raises:
        InvalidConfigError: If a non-blank string is not a valid UUID, or if
            ``model_id`` is neither a str, a UUID nor None.
    """
    if model_id is None:
        return None
    if isinstance(model_id, UUID):
        return model_id
    if isinstance(model_id, str):
        stripped = model_id.strip()
        if not stripped:
            return None
        try:
            return UUID(stripped)
        except ValueError as exc:
            # Chain explicitly so the parsing failure is recorded as the cause.
            raise InvalidConfigError(
                f"Invalid UUID format for {model_type} model ID: '{model_id}'",
                field_name=f"{model_type}_model_id",
                invalid_value=model_id,
                config_id=config_id,
                workspace_id=workspace_id
            ) from exc
    raise InvalidConfigError(
        f"Invalid type for {model_type} model ID: expected str or UUID, got {type(model_id).__name__}",
        field_name=f"{model_type}_model_id",
        invalid_value=model_id,
        config_id=config_id,
        workspace_id=workspace_id
    )
|
||||
|
||||
|
||||
def validate_model_exists_and_active(
    model_id: UUID,
    model_type: str,
    db: Session,
    tenant_id: Optional[UUID] = None,
    config_id: Optional[int] = None,
    workspace_id: Optional[UUID] = None
) -> tuple[str, bool]:
    """Validate that a model exists and is active.

    Args:
        model_id: Model UUID to validate
        model_type: Type of model ("llm", "embedding", "rerank")
        db: Database session
        tenant_id: Optional tenant ID, passed through to the repository lookup
        config_id: Optional configuration ID for error context
        workspace_id: Optional workspace ID for error context

    Returns:
        Tuple of (model_name, is_active). Inactive models raise instead of
        returning, so ``is_active`` is truthy on every successful return.

    Raises:
        ModelNotFoundError: If model does not exist
        ModelInactiveError: If model exists but is inactive
        Exception: Any other failure is logged with its traceback and
            re-raised unchanged.
    """
    # Imported here rather than at module level, as in the original code.
    from app.repositories.model_repository import ModelConfigRepository

    # perf_counter is monotonic, so the reported duration cannot be skewed
    # (or go negative) if the wall clock is adjusted during the lookup.
    start_time = time.perf_counter()

    try:
        model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        if not model:
            logger.warning(
                "Model not found",
                extra={"model_id": str(model_id), "model_type": model_type, "elapsed_ms": elapsed_ms}
            )
            raise ModelNotFoundError(
                model_id=model_id,
                model_type=model_type,
                config_id=config_id,
                workspace_id=workspace_id,
                message=f"{model_type.title()} model {model_id} not found"
            )

        if not model.is_active:
            logger.warning(
                "Model inactive",
                extra={"model_id": str(model_id), "model_name": model.name, "elapsed_ms": elapsed_ms}
            )
            raise ModelInactiveError(
                model_id=model_id,
                model_name=model.name,
                model_type=model_type,
                config_id=config_id,
                workspace_id=workspace_id,
                message=f"{model_type.title()} model {model_id} ({model.name}) is inactive"
            )

        logger.debug(
            "Model validation successful",
            extra={"model_id": str(model_id), "model_name": model.name, "elapsed_ms": elapsed_ms}
        )
        return model.name, model.is_active

    except (ModelNotFoundError, ModelInactiveError):
        # Expected validation outcomes: already logged above, pass them through.
        raise
    except Exception as e:
        logger.error(f"Model validation failed: {e}", exc_info=True)
        raise
|
||||
|
||||
|
||||
def validate_and_resolve_model_id(
    model_id_str: Union[str, UUID, None],
    model_type: str,
    db: Session,
    tenant_id: Optional[UUID] = None,
    required: bool = False,
    config_id: Optional[int] = None,
    workspace_id: Optional[UUID] = None
) -> tuple[Optional[UUID], Optional[str]]:
    """Parse a model ID and confirm the model exists and is active.

    Args:
        model_id_str: Raw model ID (str, UUID or None).
        model_type: Model kind label used in messages and field names.
        db: Database session used for the existence/active check.
        tenant_id: Optional tenant ID passed to the existence check.
        required: When True, a missing ID is an error instead of (None, None).
        config_id: Optional configuration ID for error context.
        workspace_id: Optional workspace ID for error context.

    Returns:
        Tuple of (validated_uuid, model_name) or (None, None) if not required and empty

    Raises:
        InvalidConfigError: If the ID is missing while ``required`` is True, or
            is malformed (raised by ``_parse_model_id``).
        ModelNotFoundError / ModelInactiveError: Propagated from
            ``validate_model_exists_and_active``.
    """
    # _parse_model_id yields None exactly for the "missing" inputs (None or a
    # blank string), so one check after parsing covers every missing-ID case.
    parsed_uuid = _parse_model_id(model_id_str, model_type, config_id, workspace_id)

    if parsed_uuid is None:
        if not required:
            return None, None
        raise InvalidConfigError(
            f"{model_type.title()} model ID is required",
            field_name=f"{model_type}_model_id",
            invalid_value=model_id_str,
            config_id=config_id,
            workspace_id=workspace_id
        )

    resolved_name, _is_active = validate_model_exists_and_active(
        parsed_uuid, model_type, db, tenant_id, config_id, workspace_id
    )
    return parsed_uuid, resolved_name
|
||||
|
||||
|
||||
def validate_embedding_model(
    config_id: int,
    embedding_id: Union[str, UUID, None],
    db: Session,
    tenant_id: Optional[UUID] = None,
    workspace_id: Optional[UUID] = None
) -> UUID:
    """Check that the configured embedding model is usable and return its UUID.

    Args:
        config_id: Configuration ID, used in messages and error context.
        embedding_id: Raw embedding model ID (str, UUID or None).
        db: Database session used for the lookup.
        tenant_id: Optional tenant ID passed to the lookup.
        workspace_id: Optional workspace ID for error context.

    Raises:
        InvalidConfigError: If embedding_id is not provided or invalid
        ModelNotFoundError: If embedding model does not exist
        ModelInactiveError: If embedding model is inactive
    """

    def _not_configured() -> InvalidConfigError:
        # Same error for both "missing" checks below.
        return InvalidConfigError(
            f"Configuration {config_id} has no embedding model configured",
            field_name="embedding_model_id",
            invalid_value=embedding_id,
            config_id=config_id,
            workspace_id=workspace_id
        )

    is_blank_string = isinstance(embedding_id, str) and not embedding_id.strip()
    if embedding_id is None or is_blank_string:
        raise _not_configured()

    resolved_uuid, _model_name = validate_and_resolve_model_id(
        embedding_id, "embedding", db, tenant_id, required=True,
        config_id=config_id, workspace_id=workspace_id
    )

    # Defensive: guards against a None result from the resolver.
    if resolved_uuid is None:
        raise _not_configured()

    return resolved_uuid
|
||||
|
||||
|
||||
def validate_llm_model(
    config_id: int,
    llm_id: Union[str, UUID, None],
    db: Session,
    tenant_id: Optional[UUID] = None,
    workspace_id: Optional[UUID] = None
) -> UUID:
    """Check that the configured LLM model is usable and return its UUID.

    Args:
        config_id: Configuration ID, used in messages and error context.
        llm_id: Raw LLM model ID (str, UUID or None).
        db: Database session used for the lookup.
        tenant_id: Optional tenant ID passed to the lookup.
        workspace_id: Optional workspace ID for error context.

    Raises:
        InvalidConfigError: If llm_id is not provided or invalid
        ModelNotFoundError: If LLM model does not exist
        ModelInactiveError: If LLM model is inactive
    """

    def _not_configured() -> InvalidConfigError:
        # Same error for both "missing" checks below.
        return InvalidConfigError(
            f"Configuration {config_id} has no LLM model configured",
            field_name="llm_model_id",
            invalid_value=llm_id,
            config_id=config_id,
            workspace_id=workspace_id
        )

    is_blank_string = isinstance(llm_id, str) and not llm_id.strip()
    if llm_id is None or is_blank_string:
        raise _not_configured()

    resolved_uuid, _model_name = validate_and_resolve_model_id(
        llm_id, "llm", db, tenant_id, required=True,
        config_id=config_id, workspace_id=workspace_id
    )

    # Defensive: guards against a None result from the resolver.
    if resolved_uuid is None:
        raise _not_configured()

    return resolved_uuid
|
||||
39
api/app/models/memory_config_model.py
Normal file
39
api/app/models/memory_config_model.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Memory Configuration Model - Backward Compatibility
|
||||
|
||||
This module provides backward compatibility for imports.
|
||||
All classes have been moved to app.schemas.memory_config_schema.
|
||||
|
||||
DEPRECATED: Import from app.schemas.memory_config_schema instead.
|
||||
"""
|
||||
|
||||
# Re-export for backward compatibility
|
||||
from app.schemas.memory_config_schema import (
|
||||
ConfigurationError,
|
||||
InvalidConfigError,
|
||||
MemoryConfig,
|
||||
MemoryConfigValidation,
|
||||
ModelInactiveError,
|
||||
ModelNotFoundError,
|
||||
ModelValidation,
|
||||
WorkspaceNotFoundError,
|
||||
WorkspaceValidation,
|
||||
validate_memory_config_data,
|
||||
validate_model_data,
|
||||
validate_workspace_data,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConfigurationError",
|
||||
"InvalidConfigError",
|
||||
"MemoryConfig",
|
||||
"MemoryConfigValidation",
|
||||
"ModelInactiveError",
|
||||
"ModelNotFoundError",
|
||||
"ModelValidation",
|
||||
"WorkspaceNotFoundError",
|
||||
"WorkspaceValidation",
|
||||
"validate_memory_config_data",
|
||||
"validate_model_data",
|
||||
"validate_workspace_data",
|
||||
]
|
||||
@@ -8,24 +8,26 @@ Classes:
|
||||
DataConfigRepository: 数据配置仓储类,提供CRUD操作
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_db_logger
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
# 获取配置专用日志器
|
||||
config_logger = get_config_logger()
|
||||
|
||||
|
||||
class DataConfigRepository:
|
||||
@@ -443,7 +445,129 @@ class DataConfigRepository:
|
||||
except Exception as e:
|
||||
db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]:
    """Get data config and its associated workspace information

    Runs one joined query over DataConfig and Workspace. When the join finds
    nothing, a second query on DataConfig alone distinguishes "no such
    config" from "config exists but its workspace cannot be resolved".

    Args:
        db: Database session
        config_id: Configuration ID

    Returns:
        Optional[tuple]: (DataConfig, Workspace) tuple, None if not found

    Raises:
        ValueError: Raised when config exists but workspace doesn't (either
            the config has no workspace_id, or that ID matches no workspace)
        Exception: Any other failure is logged and re-raised unchanged
    """
    import time

    from app.models.workspace_model import Workspace

    # perf_counter is monotonic, so the logged duration cannot be skewed
    # (or go negative) if the wall clock is adjusted during the query.
    start_time = time.perf_counter()

    # Log configuration loading start
    config_logger.info(
        "Loading configuration with workspace",
        extra={
            "operation": "get_config_with_workspace",
            "config_id": config_id
        }
    )

    db_logger.debug(f"Querying data config and workspace: config_id={config_id}")

    try:
        # Use join query to get both config and workspace
        result = db.query(DataConfig, Workspace).join(
            Workspace, DataConfig.workspace_id == Workspace.id
        ).filter(DataConfig.config_id == config_id).first()

        # Measured once, right after the join; the follow-up query below is
        # not included in the figure that gets logged.
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        if not result:
            # Check if config exists but workspace is missing
            config_only = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
            if config_only:
                if config_only.workspace_id is None:
                    config_logger.error(
                        "Configuration has no associated workspace ID",
                        extra={
                            "operation": "get_config_with_workspace",
                            "config_id": config_id,
                            "workspace_id": None,
                            "load_result": "no_workspace_id",
                            "elapsed_ms": elapsed_ms
                        }
                    )
                    db_logger.error(f"Data config {config_id} has no associated workspace ID")
                    raise ValueError(f"Configuration {config_id} has no associated workspace")
                else:
                    config_logger.error(
                        "Configuration references non-existent workspace",
                        extra={
                            "operation": "get_config_with_workspace",
                            "config_id": config_id,
                            "workspace_id": str(config_only.workspace_id),
                            "load_result": "workspace_not_found",
                            "elapsed_ms": elapsed_ms
                        }
                    )
                    db_logger.error(f"Data config {config_id} references non-existent workspace {config_only.workspace_id}")
                    raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")

            # No config row at all: report "not found" to the caller.
            config_logger.debug(
                "Configuration not found",
                extra={
                    "operation": "get_config_with_workspace",
                    "config_id": config_id,
                    "load_result": "not_found",
                    "elapsed_ms": elapsed_ms
                }
            )
            db_logger.debug(f"Data config not found: config_id={config_id}")
            return None

        config, workspace = result

        # Log successful configuration loading
        config_logger.info(
            "Configuration with workspace loaded successfully",
            extra={
                "operation": "get_config_with_workspace",
                "config_id": config_id,
                "config_name": config.config_name,
                "workspace_id": str(workspace.id),
                "workspace_name": workspace.name,
                "tenant_id": str(workspace.tenant_id),
                "load_result": "success",
                "elapsed_ms": elapsed_ms
            }
        )

        db_logger.debug(f"Data config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
        return (config, workspace)

    except ValueError:
        # Re-raise known business exceptions
        raise
    except Exception as e:
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        config_logger.error(
            "Failed to load configuration with workspace",
            extra={
                "operation": "get_config_with_workspace",
                "config_id": config_id,
                "load_result": "error",
                "error_type": type(e).__name__,
                "error_message": str(e),
                "elapsed_ms": elapsed_ms
            },
            exc_info=True
        )

        db_logger.error(f"Failed to query data config and workspace: config_id={config_id} - {str(e)}")
        raise
|
||||
@staticmethod
|
||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
|
||||
"""获取所有配置参数
|
||||
|
||||
@@ -8,13 +8,13 @@ class UserInput(BaseModel):
|
||||
history: list[dict]
|
||||
search_switch: str
|
||||
group_id: str
|
||||
config_id: Optional[str] = None
|
||||
config_id: str
|
||||
|
||||
|
||||
class Write_UserInput(BaseModel):
|
||||
message: str
|
||||
group_id: str
|
||||
config_id: Optional[str] = None
|
||||
config_id: str
|
||||
|
||||
class End_User_Information(BaseModel):
|
||||
end_user_name: str # 这是要更新的用户名
|
||||
|
||||
451
api/app/schemas/memory_config_schema.py
Normal file
451
api/app/schemas/memory_config_schema.py
Normal file
@@ -0,0 +1,451 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Memory Configuration Schemas
|
||||
|
||||
This module provides schema definitions for memory configuration.
|
||||
|
||||
Classes:
|
||||
MemoryConfig: Immutable memory configuration loaded from database
|
||||
MemoryConfigValidation: Pydantic model for configuration validation
|
||||
WorkspaceValidation: Pydantic model for workspace validation
|
||||
ModelValidation: Pydantic model for model configuration validation
|
||||
ConfigurationError: Base exception for configuration-related errors
|
||||
WorkspaceNotFoundError: Raised when workspace does not exist
|
||||
ModelNotFoundError: Raised when a required model does not exist
|
||||
ModelInactiveError: Raised when a required model exists but is inactive
|
||||
InvalidConfigError: Raised when configuration validation fails
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
|
||||
# ==================== Configuration Exception Classes ====================
|
||||
|
||||
|
||||
class ConfigurationError(Exception):
    """Base exception for configuration-related errors.

    Keeps the configuration ID, workspace ID and an optional context mapping
    as attributes, and folds all of them into the exception message.
    """

    def __init__(
        self,
        message: str,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
        context: Optional[Dict[str, Any]] = None,
    ):
        """Initialize configuration error with context.

        The final message has the shape
        ``Configuration <id>: <message> (workspace: <id>) [Context: k=v, ...]``,
        where each decoration appears only if the matching value was given.

        Args:
            message: Error message describing the failure
            config_id: Optional configuration ID for context
            workspace_id: Optional workspace ID for context
            context: Optional additional context information
        """
        self.config_id = config_id
        self.workspace_id = workspace_id
        self.context = context or {}

        # Assemble the message piece by piece: prefix, core text, suffixes.
        pieces = []
        if config_id is not None:
            pieces.append(f"Configuration {config_id}: ")
        pieces.append(message)
        if workspace_id is not None:
            pieces.append(f" (workspace: {workspace_id})")
        if self.context:
            rendered_pairs = [f"{key}={value}" for key, value in self.context.items()]
            pieces.append(f" [Context: {', '.join(rendered_pairs)}]")

        super().__init__("".join(pieces))
|
||||
|
||||
|
||||
class WorkspaceNotFoundError(ConfigurationError):
    """Raised when workspace does not exist."""

    def __init__(
        self,
        workspace_id: UUID,
        config_id: Optional[int] = None,
        message: Optional[str] = None,
    ):
        """Build the error, defaulting the message when none is supplied.

        Args:
            workspace_id: ID of the workspace that could not be found.
            config_id: Optional configuration ID for context.
            message: Optional custom message; a default mentioning
                ``workspace_id`` is used only when this is None.
        """
        final_message = (
            f"Workspace {workspace_id} not found in database"
            if message is None
            else message
        )
        super().__init__(
            final_message,
            config_id=config_id,
            workspace_id=workspace_id,
            context={"workspace_id": str(workspace_id)},
        )
|
||||
|
||||
|
||||
class ModelNotFoundError(ConfigurationError):
    """Raised when a required model does not exist."""

    def __init__(
        self,
        model_id: Union[str, UUID],
        model_type: str,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
        message: Optional[str] = None,
    ):
        """Build the error, defaulting the message when none is supplied.

        Args:
            model_id: ID of the model that could not be found.
            model_type: Model kind label; title-cased in the default message.
            config_id: Optional configuration ID for context.
            workspace_id: Optional workspace ID for context.
            message: Optional custom message; the default is used only when
                this is None.
        """
        final_message = (
            f"{model_type.title()} model {model_id} not found in database"
            if message is None
            else message
        )
        details = {
            "model_id": str(model_id),
            "model_type": model_type,
            "failure_type": "not_found",
        }
        super().__init__(
            final_message,
            config_id=config_id,
            workspace_id=workspace_id,
            context=details,
        )
|
||||
|
||||
|
||||
class ModelInactiveError(ConfigurationError):
    """Raised when a required model exists but is inactive."""

    def __init__(
        self,
        model_id: Union[str, UUID],
        model_name: str,
        model_type: str,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
        message: Optional[str] = None,
    ):
        """Build the error, defaulting the message when none is supplied.

        Args:
            model_id: ID of the inactive model.
            model_name: Name of the inactive model.
            model_type: Model kind label; title-cased in the default message.
            config_id: Optional configuration ID for context.
            workspace_id: Optional workspace ID for context.
            message: Optional custom message; the default is used only when
                this is None.
        """
        final_message = (
            f"{model_type.title()} model {model_id} ({model_name}) is inactive"
            if message is None
            else message
        )
        details = {
            "model_id": str(model_id),
            "model_name": model_name,
            "model_type": model_type,
            "failure_type": "inactive",
        }
        super().__init__(
            final_message,
            config_id=config_id,
            workspace_id=workspace_id,
            context=details,
        )
|
||||
|
||||
|
||||
class InvalidConfigError(ConfigurationError):
    """Raised when configuration validation fails."""

    def __init__(
        self,
        message: str,
        field_name: Optional[str] = None,
        invalid_value: Optional[Any] = None,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
    ):
        """Build the error and record which field/value was rejected.

        Args:
            message: Description of the validation failure.
            field_name: Optional name of the offending field.
            invalid_value: Optional offending value; when not None, both its
                string form and its type name are added to the context.
            config_id: Optional configuration ID for context.
            workspace_id: Optional workspace ID for context.
        """
        details: Dict[str, Any] = {}
        if field_name is not None:
            details["field_name"] = field_name
        if invalid_value is not None:
            details.update(
                invalid_value=str(invalid_value),
                invalid_value_type=type(invalid_value).__name__,
            )

        super().__init__(
            message,
            config_id=config_id,
            workspace_id=workspace_id,
            context=details,
        )
|
||||
|
||||
|
||||
# ==================== Pydantic Validation Models ====================
|
||||
|
||||
|
||||
class MemoryConfigValidation(BaseModel):
    """Pydantic schema used to check a memory configuration record.

    Unknown keys are rejected (``extra="forbid"``) and assignments made after
    construction are re-validated (``validate_assignment=True``).
    """

    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    # --- identity -----------------------------------------------------------
    config_id: int = Field(..., gt=0, description="Configuration ID must be positive")
    config_name: str = Field(..., min_length=1, max_length=255)
    workspace_id: UUID = Field(..., description="Workspace UUID")
    workspace_name: str = Field(..., min_length=1, max_length=255)
    tenant_id: UUID = Field(..., description="Tenant UUID")

    # --- models (embedding and LLM are mandatory, rerank is optional) -------
    embedding_model_id: UUID = Field(..., description="Embedding model UUID (required)")
    embedding_model_name: str = Field(..., min_length=1, max_length=255)
    llm_model_id: UUID = Field(..., description="LLM model UUID (required)")
    llm_model_name: str = Field(..., min_length=1, max_length=255)
    rerank_model_id: Optional[UUID] = Field(None, description="Rerank model UUID (optional)")
    rerank_model_name: Optional[str] = Field(None, max_length=255)

    # --- storage --------------------------------------------------------------
    storage_type: str = Field(..., min_length=1, max_length=50)

    # --- chunking / reflexion settings -----------------------------------------
    chunker_strategy: str = Field(default="RecursiveChunker", min_length=1, max_length=100)
    reflexion_enabled: bool = Field(default=False)
    reflexion_iteration_period: int = Field(default=3, ge=1, le=100)
    reflexion_range: Literal["retrieval", "all"] = Field(default="retrieval")
    reflexion_baseline: Literal["time", "fact", "time_and_fact"] = Field(default="time")

    # --- free-form model parameters and schema version --------------------------
    llm_params: Dict[str, Any] = Field(default_factory=dict)
    embedding_params: Dict[str, Any] = Field(default_factory=dict)
    config_version: str = Field(default="2.0", min_length=1, max_length=10)

    @field_validator("config_name", "workspace_name", "embedding_model_name", "llm_model_name")
    @classmethod
    def validate_non_empty_strings(cls, v):
        """Reject empty/whitespace-only names and return the trimmed value."""
        trimmed = v.strip() if v else v
        if not trimmed:
            raise ValueError("Field cannot be empty or whitespace-only")
        return trimmed

    @field_validator("storage_type")
    @classmethod
    def validate_storage_type(cls, v):
        """Lower-case the storage type and check it against the allowed list."""
        valid_types = ["neo4j", "elasticsearch", "qdrant", "milvus", "chroma"]
        normalized = v.lower()
        if normalized in valid_types:
            return normalized
        raise ValueError(f"Storage type must be one of: {valid_types}")

    @field_validator("llm_params", "embedding_params")
    @classmethod
    def validate_model_params(cls, v):
        """Ensure parameters are a dict that does not use any reserved key."""
        if not isinstance(v, dict):
            raise ValueError("Model parameters must be a dictionary")

        reserved_keys = ("model_id", "model_name", "api_key", "base_url")
        # Report the first reserved key encountered, in dict iteration order.
        offending = [key for key in v if key in reserved_keys]
        if offending:
            raise ValueError(
                f"Model parameters cannot contain reserved parameter '{offending[0]}'"
            )
        return v
|
||||
|
||||
|
||||
class WorkspaceValidation(BaseModel):
    """Pydantic schema used to check a workspace record.

    Unknown keys are rejected (``extra="forbid"``) and assignments made after
    construction are re-validated (``validate_assignment=True``).
    """

    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    id: UUID = Field(..., description="Workspace UUID")
    name: str = Field(..., min_length=1, max_length=255)
    tenant_id: UUID = Field(..., description="Tenant UUID")
    storage_type: Optional[str] = Field(None, max_length=50)
    # Model references are kept as strings; the validator below checks that
    # non-empty values parse as UUIDs.
    llm: Optional[str] = Field(None)
    embedding: Optional[str] = Field(None)
    rerank: Optional[str] = Field(None)
    is_active: bool = Field(default=True)

    @field_validator("llm", "embedding", "rerank")
    @classmethod
    def validate_model_ids(cls, v):
        """Map ``None``/``""`` to ``None``; otherwise require a UUID string.

        Returns the value with surrounding whitespace removed.
        """
        if not v:
            return None

        candidate = v.strip()
        try:
            UUID(candidate)
        except ValueError:
            raise ValueError("Model ID must be a valid UUID string")
        return candidate

    @field_validator("is_active")
    @classmethod
    def validate_active_status(cls, v):
        """Fail validation for workspaces that are not active."""
        if v:
            return v
        raise ValueError("Workspace must be active for configuration loading")
|
||||
|
||||
|
||||
class ModelValidation(BaseModel):
    """Pydantic schema used to check a model record.

    Unknown keys are rejected (``extra="forbid"``) and assignments made after
    construction are re-validated (``validate_assignment=True``).
    """

    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    id: UUID = Field(..., description="Model UUID")
    name: str = Field(..., min_length=1, max_length=255)
    type: str = Field(..., description="Model type (llm, embedding, rerank)")
    tenant_id: UUID = Field(..., description="Tenant UUID")
    is_active: bool = Field(..., description="Whether model is active")
    is_public: bool = Field(default=False)

    @field_validator("type")
    @classmethod
    def validate_type(cls, v):
        """Lower-case the model type and check it against the allowed list."""
        valid_types = ["llm", "embedding", "rerank"]
        normalized = v.lower()
        if normalized in valid_types:
            return normalized
        raise ValueError(f"Model type must be one of: {valid_types}")

    @field_validator("is_active")
    @classmethod
    def validate_active_status(cls, v):
        """Fail validation for models that are not active."""
        if v:
            return v
        raise ValueError("Model must be active for configuration use")
|
||||
|
||||
|
||||
# ==================== Validation Helper Functions ====================
|
||||
|
||||
|
||||
def _invalid_config_error_from(
    exc: ValidationError,
    heading: str,
    config_id: Optional[int] = None,
    workspace_id: Optional[UUID] = None,
) -> InvalidConfigError:
    """Translate a Pydantic ``ValidationError`` into an ``InvalidConfigError``.

    The message lists every reported error as a bullet under ``heading``; the
    first error's field path and input value are attached as
    ``field_name`` / ``invalid_value``.

    Args:
        exc: The validation error raised by the Pydantic model.
        heading: First line of the message, without the trailing colon.
        config_id: Optional configuration ID to attach to the error.
        workspace_id: Optional workspace ID to attach to the error.

    Returns:
        The ``InvalidConfigError`` to raise (not raised here so that callers
        can chain it with ``raise ... from``).
    """

    def _path(loc) -> str:
        return " -> ".join(str(part) for part in loc)

    errors = exc.errors()  # fetched once and reused for every lookup below
    bullets = [f"  - Field '{_path(err['loc'])}': {err['msg']}" for err in errors]

    first_error = errors[0] if errors else {}
    first_field = _path(first_error.get("loc", []))

    return InvalidConfigError(
        f"{heading}:\n" + "\n".join(bullets),
        field_name=first_field or None,
        invalid_value=first_error.get("input"),
        config_id=config_id,
        workspace_id=workspace_id,
    )


def validate_memory_config_data(
    config_data: Dict[str, Any], config_id: Optional[int] = None
) -> MemoryConfigValidation:
    """Validate memory configuration data using the Pydantic model.

    Args:
        config_data: Keyword data for ``MemoryConfigValidation``.
        config_id: Optional configuration ID attached to a raised error.

    Returns:
        The validated ``MemoryConfigValidation`` instance.

    Raises:
        InvalidConfigError: If Pydantic validation fails; the original
            ``ValidationError`` is chained as the cause.
    """
    try:
        return MemoryConfigValidation(**config_data)
    except ValidationError as e:
        raise _invalid_config_error_from(
            e, "Configuration validation failed", config_id=config_id
        ) from e


def validate_workspace_data(
    workspace_data: Dict[str, Any], config_id: Optional[int] = None
) -> WorkspaceValidation:
    """Validate workspace data using the Pydantic model.

    Args:
        workspace_data: Keyword data for ``WorkspaceValidation``.
        config_id: Optional configuration ID attached to a raised error.

    Returns:
        The validated ``WorkspaceValidation`` instance.

    Raises:
        InvalidConfigError: If Pydantic validation fails; carries the raw
            ``id`` entry of ``workspace_data`` (when it is a dict) as
            ``workspace_id`` and chains the original ``ValidationError``.
    """
    try:
        return WorkspaceValidation(**workspace_data)
    except ValidationError as e:
        workspace_id = workspace_data.get("id") if isinstance(workspace_data, dict) else None
        raise _invalid_config_error_from(
            e,
            "Workspace validation failed",
            config_id=config_id,
            workspace_id=workspace_id,
        ) from e


def validate_model_data(
    model_data: Dict[str, Any], config_id: Optional[int] = None
) -> ModelValidation:
    """Validate model data using the Pydantic model.

    Args:
        model_data: Keyword data for ``ModelValidation``.
        config_id: Optional configuration ID attached to a raised error.

    Returns:
        The validated ``ModelValidation`` instance.

    Raises:
        InvalidConfigError: If Pydantic validation fails; the original
            ``ValidationError`` is chained as the cause.
    """
    try:
        return ModelValidation(**model_data)
    except ValidationError as e:
        raise _invalid_config_error_from(
            e, "Model validation failed", config_id=config_id
        ) from e
|
||||
|
||||
|
||||
# ==================== Immutable Configuration Data Structure ====================
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MemoryConfig:
    """Frozen snapshot of a memory configuration.

    ``frozen=True`` prevents attributes from being reassigned. The
    ``llm_params`` / ``embedding_params`` dicts themselves remain ordinary
    mutable dicts, so ``from_validated_data`` stores private copies of them.
    """

    # Identity
    config_id: int
    config_name: str
    workspace_id: UUID
    workspace_name: str
    tenant_id: UUID

    # Required models
    embedding_model_id: UUID
    embedding_model_name: str
    llm_model_id: UUID
    llm_model_name: str

    storage_type: str

    # Chunking / reflexion settings
    chunker_strategy: str
    reflexion_enabled: bool
    reflexion_iteration_period: int
    reflexion_range: str
    reflexion_baseline: str

    # Timestamp supplied by the caller when the configuration was loaded
    loaded_at: datetime

    # Optional rerank model
    rerank_model_id: Optional[UUID] = None
    rerank_model_name: Optional[str] = None

    # Free-form model parameters and schema version
    llm_params: Dict[str, Any] = field(default_factory=dict)
    embedding_params: Dict[str, Any] = field(default_factory=dict)
    config_version: str = "2.0"

    def __post_init__(self):
        """Validate configuration after initialization.

        Raises:
            InvalidConfigError: If the name is empty/whitespace-only, or the
                embedding or LLM model ID is falsy.
        """
        if not self.config_name or not self.config_name.strip():
            raise InvalidConfigError("Configuration name cannot be empty")

        if not self.embedding_model_id:
            raise InvalidConfigError("Embedding model ID is required")

        if not self.llm_model_id:
            raise InvalidConfigError("LLM model ID is required")

    @classmethod
    def from_validated_data(
        cls, validated_config: "MemoryConfigValidation", loaded_at: datetime
    ) -> "MemoryConfig":
        """Create a ``MemoryConfig`` from validated Pydantic data.

        Args:
            validated_config: Object exposing the validated configuration
                fields as attributes.
            loaded_at: Timestamp to record as the load time.

        Returns:
            A new ``MemoryConfig`` populated from ``validated_config``.
        """
        return cls(
            config_id=validated_config.config_id,
            config_name=validated_config.config_name,
            workspace_id=validated_config.workspace_id,
            workspace_name=validated_config.workspace_name,
            tenant_id=validated_config.tenant_id,
            embedding_model_id=validated_config.embedding_model_id,
            embedding_model_name=validated_config.embedding_model_name,
            llm_model_id=validated_config.llm_model_id,
            llm_model_name=validated_config.llm_model_name,
            storage_type=validated_config.storage_type,
            chunker_strategy=validated_config.chunker_strategy,
            reflexion_enabled=validated_config.reflexion_enabled,
            reflexion_iteration_period=validated_config.reflexion_iteration_period,
            reflexion_range=validated_config.reflexion_range,
            reflexion_baseline=validated_config.reflexion_baseline,
            loaded_at=loaded_at,
            rerank_model_id=validated_config.rerank_model_id,
            rerank_model_name=validated_config.rerank_model_name,
            # Shallow copies: later mutation of the source object's dicts must
            # not leak into this frozen snapshot (and vice versa).
            llm_params=dict(validated_config.llm_params),
            embedding_params=dict(validated_config.embedding_params),
            config_version=validated_config.config_version,
        )

    def get_model_summary(self) -> Dict[str, Optional[str]]:
        """Return the configured model names keyed by model type."""
        return {
            "llm": self.llm_model_name,
            "embedding": self.embedding_model_name,
            "rerank": self.rerank_model_name,
        }

    def is_model_configured(self, model_type: str) -> bool:
        """Check if a specific model type is configured.

        ``"llm"`` and ``"embedding"`` always report ``True`` (both IDs are
        required by ``__post_init__``); ``"rerank"`` depends on whether a
        rerank model ID is set.

        Raises:
            ValueError: If ``model_type`` is not one of the three known types.
        """
        if model_type in ("llm", "embedding"):
            return True
        if model_type == "rerank":
            return self.rerank_model_id is not None
        raise ValueError(f"Unknown model type: {model_type}")
|
||||
@@ -4,45 +4,44 @@ Memory Agent Service
|
||||
Handles business logic for memory agent operations including read/write services,
|
||||
health checks, and message type classification.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.db import get_db
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.schemas.memory_storage_schema import ApiResponse, ok, fail
|
||||
from app.db import get_db
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.services.memory_konwledges_server import memory_konwledges_up, SimpleUser, find_document_id_by_kb_and_filename
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
from app.schemas.file_schema import CustomTextFileCreate
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
# Initialize Neo4j connector for analytics functions
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
@@ -56,6 +55,27 @@ class MemoryAgentService:
|
||||
self.user_locks: Dict[str, Lock] = {}
|
||||
self.locks_lock = Lock()
|
||||
|
||||
def load_memory_config(self, config_id: int) -> MemoryConfig:
    """
    Fetch the memory configuration identified by ``config_id``.

    Thin delegate: the actual loading is performed by
    ``MemoryConfigService.load_memory_config``, called with this service's
    name as ``service_name``.

    Args:
        config_id: Configuration ID from database

    Returns:
        MemoryConfig: whatever the centralized service returns

    Raises:
        ConfigurationError: If validation fails
    """
    loaded_config = MemoryConfigService.load_memory_config(
        config_id=config_id,
        service_name="MemoryAgentService",
    )
    return loaded_config
|
||||
|
||||
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
|
||||
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
countext = re.findall(r'"status": "(.*?)",', messages)[0]
|
||||
@@ -277,19 +297,20 @@ class MemoryAgentService:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 如果 config_id 为 None,使用默认值 "17"
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}"
|
||||
# Load configuration from database only
|
||||
try:
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 记录失败的操作
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation( operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg )
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
|
||||
mcp_config = get_mcp_server_config()
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
|
||||
@@ -300,20 +321,43 @@ class MemoryAgentService:
|
||||
async with client.session("data_flow") as session:
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
tools = await load_mcp_tools(session)
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
# Pass config_id to the graph workflow
|
||||
async with make_write_graph(group_id, tools, group_id, group_id, config_id=config_id) as graph:
|
||||
# Pass memory_config to the graph workflow
|
||||
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
|
||||
logger.debug("Write graph created successfully")
|
||||
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
|
||||
async for event in graph.astream(
|
||||
{"messages": message, "config_id": config_id},
|
||||
{"messages": message, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
messages = event.get('messages')
|
||||
return self.writer_messages_deal(messages,start_time,group_id,config_id,message)
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.error(f"Write workflow failed with errors: {error_details}")
|
||||
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_details
|
||||
)
|
||||
|
||||
raise ValueError(f"Write workflow failed: {error_details}")
|
||||
|
||||
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
@@ -365,15 +409,15 @@ class MemoryAgentService:
|
||||
group_lock = self.get_group_lock(group_id)
|
||||
|
||||
with group_lock:
|
||||
# Step 1: Load configuration from database
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}"
|
||||
# Step 1: Load configuration from database only
|
||||
try:
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 记录失败的操作
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
@@ -387,8 +431,6 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
|
||||
|
||||
# Step 2: Prepare history
|
||||
history.append({"role": "user", "content": message})
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
@@ -404,45 +446,52 @@ class MemoryAgentService:
|
||||
intermediate_outputs = []
|
||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
||||
|
||||
# Pass config_id to the graph workflow
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, config_id=config_id,storage_type=storage_type,user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
# Pass memory_config to the graph workflow
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
start = time.time()
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
async for event in graph.astream(
|
||||
{"messages": history, "config_id": config_id},
|
||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
messages = event.get('messages')
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
for msg in messages:
|
||||
msg_content = msg.content
|
||||
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
||||
outputs.append({
|
||||
"role": msg.__class__.__name__.lower().replace("message", ""),
|
||||
"role": msg_role,
|
||||
"content": msg_content
|
||||
})
|
||||
|
||||
# Extract intermediate outputs
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
# Debug: log message type and content preview
|
||||
msg_type = msg.__class__.__name__
|
||||
content_preview = str(msg_content)[:200] if msg_content else "empty"
|
||||
logger.debug(f"Processing message type={msg_type}, content preview={content_preview}")
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
content_to_parse = msg_content
|
||||
if isinstance(msg_content, list):
|
||||
for block in msg_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
content_to_parse = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
# Try to parse content as JSON
|
||||
if isinstance(msg_content, str):
|
||||
if isinstance(content_to_parse, str):
|
||||
try:
|
||||
parsed = json.loads(msg_content)
|
||||
parsed = json.loads(content_to_parse)
|
||||
if isinstance(parsed, dict):
|
||||
# Debug: log what keys are in parsed
|
||||
logger.debug(f"Parsed dict keys: {list(parsed.keys())}")
|
||||
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in parsed:
|
||||
intermediate_data = parsed['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
logger.debug(f"Found _intermediate: {intermediate_data.get('type', 'unknown')}")
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
@@ -450,34 +499,14 @@ class MemoryAgentService:
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in parsed:
|
||||
logger.debug(f"Found _intermediates list with {len(parsed['_intermediates'])} items")
|
||||
for intermediate_data in parsed['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
logger.debug(f"Processing intermediate: {intermediate_data.get('type', 'unknown')}")
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
elif isinstance(msg_content, dict):
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in msg_content:
|
||||
intermediate_data = msg_content['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in msg_content:
|
||||
for intermediate_data in msg_content['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||
|
||||
@@ -489,18 +518,57 @@ class MemoryAgentService:
|
||||
for messages in outputs:
|
||||
if messages['role'] == 'tool':
|
||||
message = messages['content']
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(message, list):
|
||||
# Extract text from MCP content blocks
|
||||
for block in message:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
message = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
try:
|
||||
message = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(message, dict) and message.get('status') != '':
|
||||
summary_result = message.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
parsed = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(parsed, dict):
|
||||
if parsed.get('status') == 'success':
|
||||
summary_result = parsed.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# 记录成功的操作
|
||||
total_duration = time.time() - start_time
|
||||
if audit_logger:
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.warning(f"Read workflow completed with errors: {error_details}")
|
||||
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=total_duration,
|
||||
error=error_details,
|
||||
details={
|
||||
"search_switch": search_switch,
|
||||
"history_length": len(history),
|
||||
"intermediate_outputs_count": len(intermediate_outputs),
|
||||
"has_answer": bool(final_answer),
|
||||
"errors": workflow_errors
|
||||
}
|
||||
)
|
||||
|
||||
# Raise error if no answer was produced
|
||||
if not final_answer:
|
||||
raise ValueError(f"Read workflow failed: {error_details}")
|
||||
|
||||
if audit_logger and not workflow_errors:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
@@ -612,19 +680,25 @@ class MemoryAgentService:
|
||||
else:
|
||||
return output
|
||||
|
||||
async def classify_message_type(self, message: str) -> Dict:
|
||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
message: User message to classify
|
||||
config_id: Configuration ID to load LLM model from database
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Type classification result
|
||||
"""
|
||||
logger.info("Classifying message type")
|
||||
|
||||
status = await status_typle(message)
|
||||
# Load configuration to get LLM model ID
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
|
||||
status = await status_typle(message, memory_config.llm_model_id)
|
||||
logger.debug(f"Message type: {status}")
|
||||
return status
|
||||
|
||||
@@ -790,7 +864,8 @@ class MemoryAgentService:
|
||||
async def get_user_profile(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user_id: Optional[str] = None
|
||||
current_user_id: Optional[str] = None,
|
||||
llm_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -801,6 +876,7 @@ class MemoryAgentService:
|
||||
参数:
|
||||
- end_user_id: 用户ID(可选)
|
||||
- current_user_id: 当前登录用户的ID(保留参数)
|
||||
- llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成)
|
||||
|
||||
返回格式:
|
||||
{
|
||||
@@ -862,15 +938,17 @@ class MemoryAgentService:
|
||||
|
||||
await connector.close()
|
||||
|
||||
if not statements:
|
||||
if not statements or not llm_id:
|
||||
result["tags"] = []
|
||||
if not llm_id and statements:
|
||||
logger.warning("llm_id not provided, skipping tag generation")
|
||||
else:
|
||||
# 构建摘要文本
|
||||
summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}"
|
||||
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
|
||||
|
||||
# 使用LLM提取标签
|
||||
llm_client = get_llm_client()
|
||||
llm_client = get_llm_client(llm_id)
|
||||
|
||||
# 定义标签提取的结构
|
||||
class UserTags(BaseModel):
|
||||
|
||||
264
api/app/services/memory_config_service.py
Normal file
264
api/app/services/memory_config_service.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Memory Configuration Service
|
||||
|
||||
Centralized configuration loading and management for memory services.
|
||||
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
|
||||
Database session management is handled internally.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.validators.memory_config_validators import (
|
||||
validate_and_resolve_model_id,
|
||||
validate_embedding_model,
|
||||
validate_model_exists_and_active,
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.schemas.memory_config_schema import (
|
||||
ConfigurationError,
|
||||
InvalidConfigError,
|
||||
MemoryConfig,
|
||||
ModelInactiveError,
|
||||
ModelNotFoundError,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
|
||||
def _validate_config_id(config_id):
|
||||
"""Validate configuration ID format."""
|
||||
if config_id is None:
|
||||
raise InvalidConfigError(
|
||||
"Configuration ID cannot be None",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
if isinstance(config_id, int):
|
||||
if config_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {config_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return config_id
|
||||
|
||||
if isinstance(config_id, str):
|
||||
try:
|
||||
parsed_id = int(config_id.strip())
|
||||
if parsed_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {parsed_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return parsed_id
|
||||
except ValueError as e:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid configuration ID format: '{config_id}'",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
raise InvalidConfigError(
|
||||
f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
|
||||
class MemoryConfigService:
    """
    Centralized service for memory configuration loading and validation.

    This class provides a single implementation of configuration loading logic
    that can be shared across multiple services, eliminating code duplication.
    Database session management is handled internally by
    ``load_memory_config``; ``_load_memory_config_with_db`` works on a session
    supplied by the caller.
    """

    @staticmethod
    def load_memory_config(
        config_id: int,
        service_name: str = "MemoryConfigService",
    ) -> MemoryConfig:
        """
        Load memory configuration from database by config_id.

        This method manages its own database session internally: a session is
        taken from ``get_db()`` and closed again in a ``finally`` clause,
        whether loading succeeds or fails.

        Args:
            config_id: Configuration ID from database
            service_name: Name of the calling service (for logging purposes)

        Returns:
            MemoryConfig: Immutable configuration object

        Raises:
            ConfigurationError: If validation fails
            ValueError: Propagated unchanged from the loading step
        """
        # Function-scope import kept from the original.
        # NOTE(review): presumably avoids an import cycle with app.db - confirm.
        from app.db import get_db

        db_gen = get_db()
        db = next(db_gen)

        try:
            return MemoryConfigService._load_memory_config_with_db(
                config_id=config_id,
                db=db,
                service_name=service_name,
            )
        finally:
            db.close()

    @staticmethod
    def _load_memory_config_with_db(
        config_id: int,
        db: Session,
        service_name: str = "MemoryConfigService",
    ) -> MemoryConfig:
        """Internal method that loads memory configuration with an existing db session.

        Steps: validate ``config_id``, fetch the configuration row together
        with its workspace, validate/resolve the embedding, LLM and (optional)
        rerank models, then assemble a ``MemoryConfig``. Start, success and
        failure are each reported to both ``config_logger`` (structured
        ``extra`` payload) and ``logger``.

        Args:
            config_id: Configuration ID to load.
            db: Open database session; it is not closed here.
            service_name: Name of the calling service (for logging purposes).

        Returns:
            MemoryConfig: The assembled configuration object.

        Raises:
            ConfigurationError: If the configuration row is missing, or if any
                exception other than ``ConfigurationError``/``ValueError`` is
                raised while loading (the original is attached as the cause).
            ValueError: Re-raised unchanged.
        """
        start_time = time.time()

        config_logger.info(
            "Starting memory configuration loading",
            extra={
                "operation": "load_memory_config",
                "service": service_name,
                "config_id": config_id,
            },
        )

        logger.info(f"Loading memory configuration from database: config_id={config_id}")

        try:
            validated_config_id = _validate_config_id(config_id)

            result = DataConfigRepository.get_config_with_workspace(db, validated_config_id)
            if not result:
                elapsed_ms = (time.time() - start_time) * 1000
                config_logger.error(
                    "Configuration not found in database",
                    extra={
                        "operation": "load_memory_config",
                        "config_id": validated_config_id,
                        "load_result": "not_found",
                        "elapsed_ms": elapsed_ms,
                        "service": service_name,
                    },
                )
                # Raised inside the try on purpose: the handler below logs the
                # failure a second time and then re-raises it unchanged.
                raise ConfigurationError(
                    f"Configuration {validated_config_id} not found in database"
                )

            memory_config, workspace = result

            # Validate embedding model
            embedding_uuid = validate_embedding_model(
                validated_config_id,
                memory_config.embedding_id,
                db,
                workspace.tenant_id,
                workspace.id,
            )

            # Resolve LLM model
            llm_uuid, llm_name = validate_and_resolve_model_id(
                memory_config.llm_id,
                "llm",
                db,
                workspace.tenant_id,
                required=True,
                config_id=validated_config_id,
                workspace_id=workspace.id,
            )

            # Resolve optional rerank model; both values stay None when the
            # configuration has no rerank_id.
            rerank_uuid = None
            rerank_name = None
            if memory_config.rerank_id:
                rerank_uuid, rerank_name = validate_and_resolve_model_id(
                    memory_config.rerank_id,
                    "rerank",
                    db,
                    workspace.tenant_id,
                    required=False,
                    config_id=validated_config_id,
                    workspace_id=workspace.id,
                )

            # Get embedding model name (first element of the returned pair;
            # the second element is not used here).
            embedding_name, _ = validate_model_exists_and_active(
                embedding_uuid,
                "embedding",
                db,
                workspace.tenant_id,
                config_id=validated_config_id,
                workspace_id=workspace.id,
            )

            # Create immutable MemoryConfig object; falsy stored values fall
            # back to the defaults spelled out on each line.
            config = MemoryConfig(
                config_id=memory_config.config_id,
                config_name=memory_config.config_name,
                workspace_id=workspace.id,
                workspace_name=workspace.name,
                tenant_id=workspace.tenant_id,
                llm_model_id=llm_uuid,
                llm_model_name=llm_name,
                embedding_model_id=embedding_uuid,
                embedding_model_name=embedding_name,
                rerank_model_id=rerank_uuid,
                rerank_model_name=rerank_name,
                storage_type=workspace.storage_type or "neo4j",
                chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
                reflexion_enabled=memory_config.enable_self_reflexion or False,
                reflexion_iteration_period=int(memory_config.iteration_period or "3"),
                reflexion_range=memory_config.reflexion_range or "retrieval",
                reflexion_baseline=memory_config.baseline or "time",
                loaded_at=datetime.now(),
            )

            elapsed_ms = (time.time() - start_time) * 1000

            config_logger.info(
                "Memory configuration loaded successfully",
                extra={
                    "operation": "load_memory_config",
                    "service": service_name,
                    "config_id": validated_config_id,
                    "config_name": config.config_name,
                    "workspace_id": str(config.workspace_id),
                    "load_result": "success",
                    "elapsed_ms": elapsed_ms,
                },
            )

            logger.info(f"Memory configuration loaded successfully: {config.config_name}")
            return config

        except Exception as e:
            elapsed_ms = (time.time() - start_time) * 1000

            # The raw config_id is logged here because validation itself may be
            # what failed, in which case no validated ID exists.
            config_logger.error(
                "Failed to load memory configuration",
                extra={
                    "operation": "load_memory_config",
                    "service": service_name,
                    "config_id": config_id,
                    "load_result": "error",
                    "error_type": type(e).__name__,
                    "error_message": str(e),
                    "elapsed_ms": elapsed_ms,
                },
                exc_info=True,
            )

            logger.error(f"Failed to load memory configuration {config_id}: {e}")

            # Known error types propagate as they are; anything else is wrapped
            # so callers only need to handle ConfigurationError/ValueError.
            if isinstance(e, (ConfigurationError, ValueError)):
                raise
            # Chain explicitly so the original exception is kept as __cause__.
            raise ConfigurationError(
                f"Failed to load configuration {config_id}: {e}"
            ) from e
|
||||
@@ -4,39 +4,40 @@ Memory Storage Service
|
||||
Handles business logic for memory storage operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.core.logging_config import get_logger
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
ConfigPilotRun,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
)
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
# Load environment variables for Neo4j connector
|
||||
load_dotenv()
|
||||
@@ -48,6 +49,27 @@ class MemoryStorageService:
|
||||
|
||||
def __init__(self):
    """Create the service; the only action is logging that it was initialized."""
    logger.info("MemoryStorageService initialized")
|
||||
|
||||
def load_memory_config(self, config_id: int, db: Session) -> MemoryConfig:
    """
    Fetch the memory configuration identified by ``config_id``.

    The work is handed over to ``MemoryConfigService.load_memory_config`` so
    that the loading logic lives in one place, with this service reported as
    the caller.

    Args:
        config_id: Configuration ID from database
        db: Database session. It is accepted but not referenced in this
            method; only ``config_id`` and the service name are forwarded.

    Returns:
        MemoryConfig: Immutable configuration object

    Raises:
        ConfigurationError: If validation fails
    """
    caller_name = "MemoryStorageService"
    loaded = MemoryConfigService.load_memory_config(
        config_id=config_id,
        service_name=caller_name,
    )
    return loaded
|
||||
|
||||
async def get_storage_info(self) -> dict:
|
||||
"""
|
||||
@@ -248,7 +270,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
RuntimeError: 当管线执行失败时
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
|
||||
|
||||
try:
|
||||
# 发出初始进度事件
|
||||
@@ -257,24 +278,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
# 步骤 1: 配置加载和验证(复用现有逻辑)
|
||||
# 步骤 1: 配置加载和验证(数据库优先)
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
raise ValueError("未提供 payload.config_id,禁止启动试运行")
|
||||
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
@@ -282,12 +291,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
|
||||
# 应用内存覆写并刷新常量
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
# Load configuration from database only using centralized manager
|
||||
try:
|
||||
memory_config = MemoryConfigService.load_memory_config(
|
||||
config_id=int(cid),
|
||||
service_name="MemoryStorageService.pilot_run_stream"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
raise RuntimeError(f"Configuration loading failed: {e}")
|
||||
|
||||
# 步骤 2: 创建进度回调函数捕获管线进度
|
||||
# 使用队列在回调和生成器之间传递进度事件
|
||||
|
||||
@@ -1,28 +1,30 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
import requests
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from math import ceil
|
||||
import redis
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.db import get_db
|
||||
from app.models.document_model import Document
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.core.config import settings
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
import redis
|
||||
import requests
|
||||
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from app.db import get_db
|
||||
from app.models.document_model import Document
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.process_item")
|
||||
@@ -221,11 +223,17 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
except Exception as e:
|
||||
except BaseException as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
# Handle ExceptionGroup from TaskGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
else:
|
||||
detailed_error = str(e)
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"error": detailed_error,
|
||||
"group_id": group_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -283,11 +291,17 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
except Exception as e:
|
||||
except BaseException as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
# Handle ExceptionGroup from TaskGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
else:
|
||||
detailed_error = str(e)
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"error": detailed_error,
|
||||
"group_id": group_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -300,9 +314,10 @@ def reflection_engine() -> None:
|
||||
|
||||
Intentionally left blank; replace with real reflection logic later.
|
||||
"""
|
||||
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
|
||||
import asyncio
|
||||
|
||||
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
|
||||
|
||||
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
|
||||
asyncio.run(self_reflexion(host_id))
|
||||
|
||||
@@ -377,10 +392,10 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.services.memory_storage_service import search_all
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.services.memory_storage_service import search_all
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user